blob: e7704f1629575e59cfaadc7a573a1c828085017a [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnsonaf090182024-02-13 18:25:39 +00004import logging
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01006from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01007from datetime import datetime
8from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07009
Jeremy Johnson1271c442023-09-05 11:39:26 +010010import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000011import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010013from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010014from generator.tosa_arg_gen import TosaArgGen
15from generator.tosa_arg_gen import TosaQuantGen
16from generator.tosa_arg_gen import TosaTensorGen
17from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000018from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010019from generator.tosa_error_if import TosaErrorIfArgGen
20from generator.tosa_error_if import TosaErrorValidator
21from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010022from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000023from tosa.DType import DType
24from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010025
# Banner written at the top of every generated C++ meta data file; the
# copyright year is the year in which the generator is run.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)

# Module logger for the test builder
logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")
33
Matthew Haddonb724efc2021-08-25 16:40:29 +010034
Eric Kunzee5e26762020-10-13 16:11:07 -070035class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010036 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000037 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010038 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010039 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010040 TOSA_8K_LEVEL_MAX_KERNEL = 8192
41 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010042
Jeremy Johnson1271c442023-09-05 11:39:26 +010043 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000044 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010045 TOSA_MI_DOT_PRODUCT_MIN = 1000
46
def __init__(self, args):
    """Initialise the test generator from the parsed command-line ``args``.

    Reads ``args.output_dir``, ``args.random_seed``,
    ``args.generate_lib_path`` and ``args.tensor_fp_value_range`` here (other
    methods read further fields).  Seeds ``self.rng``, builds the op lists,
    and pre-computes ``self.random_float_range`` for every floating point
    dtype.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # Force makeShape to do a specific starting shape
    self.targetted_shape = None
    # JSON schema validation
    self.descSchemaValidator = TestDescSchemaValidator()
    # Data generator library is sometimes needed for compliance set up
    # even if we are generating the data later (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def clamp_fp_range(user_range, fp_max):
        # Map "max"/"-max" to +/- fp_max, clamp other non-zero values into
        # [-fp_max, fp_max] and return the values as a sorted tuple.
        bounded = []
        for value in user_range:
            if value == "max":
                bounded.append(fp_max)
            elif value == "-max":
                bounded.append(-fp_max)
            elif value < 0:
                bounded.append(max(value, -fp_max))
            elif value > 0:
                bounded.append(min(value, fp_max))
            else:
                bounded.append(value)
        return tuple(sorted(bounded))

    # Floating point value range per dtype, limited by that dtype's maximum
    self.random_float_range = {
        fp_dtype: clamp_fp_range(
            args.tensor_fp_value_range,
            TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[fp_dtype],
        )
        for fp_dtype in (
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        )
    }
88
def createSerializer(self, opName, testPath):
    """Create the output directory for a test and a fresh serializer for it.

    The test lives in ``<basePath>/<opName>/<testPath>``.  Constants are
    written as input files for lazy data generation, embedded and dumped when
    ``--dump-consts`` is set, and embedded in the flatbuffer otherwise.
    """
    self.testPath = os.path.join(opName, testPath)
    full_path = os.path.join(self.basePath, self.testPath)
    os.makedirs(full_path, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        const_mode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        const_mode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        const_mode = ts.ConstMode.EMBED

    self.ser = ts.TosaSerializer(full_path, const_mode)
Eric Kunzee5e26762020-10-13 16:11:07 -0700102
def getSerializer(self):
    """Return the serializer made by the last ``createSerializer`` call (None before that)."""
    current = self.ser
    return current
105
def serialize(self, testName, metaData=None):
    """Write the current test's flatbuffer, meta data files and desc.json.

    Files are written under ``<basePath>/<testPath>``:

    * ``<testName>.tosa``: the serialized TOSA flatbuffer.
    * ``<testName>_meta_data_gen.cpp``: only when ``metaData`` has a
      "data_gen" entry and lazy data generation is enabled.
    * ``<testName>_meta_compliance.cpp``: only when ``metaData`` has a
      "compliance" entry.
    * ``desc.json``: the serializer's JSON descriptor, with ``metaData``
      stored under "meta" when given.  It is passed to the schema
      validator before any meta data file is written.
    """
    test_dir = Path(self.basePath) / self.testPath

    def write_cpp_string(file_name, banner, declaration, payload):
        # Emit the JSON payload wrapped in a C++ raw string literal
        with (test_dir / file_name).open("w") as cpp_file:
            cpp_file.write(TOSA_AUTOGENERATED_HEADER)
            cpp_file.write(banner)
            cpp_file.write(declaration)
            json.dump(payload, cpp_file)
            cpp_file.write(')";\n\n')

    # TOSA flatbuffer binary
    with (test_dir / f"{testName}.tosa").open("wb") as tosa_file:
        tosa_file.write(self.ser.serialize())

    # JSON descriptor from the serializer, plus any extra meta data
    desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
    if metaData:
        desc["meta"] = metaData

    # Validate desc.json before it is written out
    self.descSchemaValidator.validate_config(desc)

    if metaData:
        if "data_gen" in metaData and self.args.lazy_data_gen:
            # Data generation set up, only needed when data is made later
            write_cpp_string(
                f"{testName}_meta_data_gen.cpp",
                "// Test meta data for data generation setup\n\n",
                f'const char* json_tdg_config_{test_dir.stem} = R"(',
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            # Compliance validation set up
            write_cpp_string(
                f"{testName}_meta_compliance.cpp",
                "// Test meta data for compliance validation\n\n",
                f'const char* json_tvf_config_{test_dir.stem} = R"(',
                metaData["compliance"],
            )

    with (test_dir / "desc.json").open("w") as desc_file:
        json.dump(desc, desc_file, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700149
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a new generator.

    Uses ``seed`` when given (including 0), otherwise ``self.random_seed + 1``.
    """
    chosen_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(chosen_seed)
154
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the ``(low, high)`` value range used to generate ``dtype`` data.

    Floating point dtypes return their ``self.random_float_range`` entry
    unchanged; ``high_inclusive`` is not applied to them.  For every other
    dtype ``high`` is exclusive, unless ``high_inclusive`` is True, in which
    case it is reduced by one.

    Raises:
        Exception: if ``dtype`` is not a recognised data type.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
        return self.random_float_range[dtype]

    def signed(bits):
        # Two's complement range of a `bits`-wide integer, high exclusive
        return (-(1 << (bits - 1)), 1 << (bits - 1))

    def unsigned(bits):
        return (0, 1 << bits)

    if dtype == DType.BOOL:
        bounds = unsigned(1)
    elif dtype == DType.UINT8:
        bounds = unsigned(8)
    elif dtype == DType.UINT16:
        bounds = unsigned(16)
    elif dtype == DType.INT4:
        # TOSA specific INT4 weight range from -7 to 7
        bounds = (-7, 8)
    elif dtype == DType.INT8:
        bounds = signed(8)
    elif dtype == DType.INT16:
        bounds = signed(16)
    elif dtype == DType.INT32:
        bounds = signed(32)
    elif dtype == DType.SHAPE:
        # Shape values use the configured tensor dimension range
        bounds = tuple(self.args.tensor_shape_range[0:2])
    elif dtype == DType.INT48:
        bounds = signed(48)
    else:
        raise Exception("Unknown dtype: {}".format(dtype))

    if high_inclusive:
        # Inclusive range: low <= value <= high
        return (bounds[0], bounds[1] - 1)
    # Exclusive high: low <= value < high
    return bounds
189
def getRandTensor(self, shape, dtype, data_range=None):
    """Return a random numpy array of ``shape`` holding ``dtype`` values.

    Values are drawn from ``data_range`` as ``(low, high)``; when it is None,
    ``getDTypeRange(dtype)`` is used.  BOOL data is sampled from
    False/True.  FP16 data is returned as float16.  BF16, FP8E4M3 and
    FP8E5M2 data is converted through the matching ``gtu`` helper and
    returned as float32.  FP32 data is returned as float32.  Integer data
    uses the smallest numpy container the code maps it to: int8, uint8,
    int16, uint16, int64 or int32.
    """
    low, high = self.getDTypeRange(dtype) if data_range is None else data_range

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))

    if dtype in (
        DType.FP16,
        DType.BF16,
        DType.FP32,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ):
        samples = self.rng.uniform(low=low, high=high, size=shape)
        if dtype == DType.FP16:
            return np.float16(samples)
        samples_f32 = np.float32(samples)
        if dtype == DType.BF16:
            # Floor the last 16 bits of each f32 value
            return np.float32(gtu.vect_f32_to_bf16(samples_f32))
        if dtype == DType.FP8E4M3:
            return np.float32(gtu.vect_f32_to_fp8e4m3(samples_f32))
        if dtype == DType.FP8E5M2:
            return np.float32(gtu.vect_f32_to_fp8e5m2(samples_f32))
        return samples_f32

    # Integer types: choose the numpy container, then draw the values
    if dtype in (DType.INT4, DType.INT8):
        container = np.int8
    elif dtype == DType.UINT8:
        container = np.uint8
    elif dtype == DType.INT16:
        container = np.int16
    elif dtype == DType.UINT16:
        container = np.uint16
    elif dtype in (DType.INT48, DType.SHAPE):
        container = np.int64
    else:
        # All other integer types
        container = np.int32
    return container(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700235
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder per ``(shape, dtype)`` pair and return them in order.

    Random data is attached unless lazy data generation is enabled, in which
    case the placeholders are added with no data (None).
    """
    assert len(shape_list) == len(dtype_list)

    tensors = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        tensors.append(self.ser.addPlaceholder(shape, dtype, data))
    return tensors
248
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant per ``(shape, dtype)`` pair and return them in order.

    Random data is attached unless lazy data generation is enabled, in which
    case the constants are added with no data (None).
    """
    assert len(shape_list) == len(dtype_list)

    def const_data(shape, dtype):
        # No data is generated here when it will be generated later
        return None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)

    return [
        self.ser.addConst(shape, dtype, const_data(shape, dtype))
        for shape, dtype in zip(shape_list, dtype_list)
    ]
261
def makeShape(self, rank):
    """Return an int32 shape array.

    If a target shape has been set (and is truthy), that shape is returned.
    Otherwise ``rank`` dimensions are drawn uniformly from
    ``[tensor_shape_range[0], tensor_shape_range[1])``.
    """
    if not self.targetted_shape:
        dim_low = self.args.tensor_shape_range[0]
        dim_high = self.args.tensor_shape_range[1]
        dims = self.rng.integers(low=dim_low, high=dim_high, size=rank)
        return np.int32(dims)
    return np.int32(self.targetted_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700272
def setTargetShape(self, shape):
    """Make ``makeShape`` return ``shape`` instead of a random shape.

    Passing a falsy value (e.g. None) restores random shapes.
    """
    self.targetted_shape = shape
275
def randInt(self, low=0, high=256):
    """Return one random ``np.int32`` in ``[low, high)`` drawn from ``self.rng``."""
    single_draw = self.rng.integers(low=low, high=high, size=1)
    as_int32 = np.int32(single_draw)
    return as_int32[0]
278
def getRandNumberDType(self, dtype):
    """Return a single random value of ``dtype`` drawn from ``getDTypeRange``.

    FP32/FP16 values are returned as ``np.float32``/``np.float16``.
    BF16/FP8 values are drawn as float32 and returned from the matching
    ``gtu`` conversion helper.  BOOL values are a random choice of
    False/True.  INT48/SHAPE values are ``np.int64``; all other types are
    ``np.int32``.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype in (DType.FP32, DType.FP16):
        to_float = np.float32 if dtype == DType.FP32 else np.float16
        return to_float(self.rng.uniform(low=low, high=high))

    if dtype in (DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
        sample = np.float32(self.rng.uniform(low=low, high=high))
        if dtype == DType.BF16:
            return gtu.vect_f32_to_bf16(sample)
        if dtype == DType.FP8E4M3:
            return gtu.vect_f32_to_fp8e4m3(sample)
        return gtu.vect_f32_to_fp8e5m2(sample)

    if dtype == DType.BOOL:
        return self.rng.choice([False, True])

    # Special size for INT48/SHAPE, int32 for everything else
    container = np.int64 if dtype in (DType.INT48, DType.SHAPE) else np.int32
    return container(self.rng.integers(low, high, size=1))[0]
302
def shapeStr(self, shape):
    """Return ``shape`` as an ``x``-separated string, e.g. ``[1, 2, 3]`` -> ``"1x2x3"``."""
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700311
def typeStr(self, dtype):
    """Return the short string name of ``dtype``, as used in test names.

    For a list/tuple of at least two dtypes, every entry is converted (so an
    unknown entry still raises) but only the first two names are joined with
    ``x``; a third entry is the accumulator type and is not part of the name.

    Raises:
        Exception: if a dtype is not in ``gtu.DTYPE_ATTRIBUTES``.
    """
    if not isinstance(dtype, (list, tuple)):
        if dtype not in gtu.DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return gtu.DTYPE_ATTRIBUTES[dtype]["str"]

    assert len(dtype) >= 2
    names = [self.typeStr(member) for member in dtype]
    # Limit types to the first 2 as the 3rd is the accumulator
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700325
def constrictBatchSize(self, shape):
    """Cap ``shape[0]`` at ``args.max_batch_size`` and return ``shape``.

    ``shape`` is modified in place.  The cap is skipped when no maximum
    batch size is set or when explicit target shapes were requested.
    """
    batch_limit = self.args.max_batch_size
    if batch_limit and not self.args.target_shapes:
        shape[0] = min(shape[0], batch_limit)
    return shape
331
def makeDimension(self):
    """Return one random dimension size from the configured tensor shape range."""
    dim_range = self.args.tensor_shape_range
    return self.randInt(low=dim_range[0], high=dim_range[1])
336
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Build the compliance meta data entry for one expected output tensor.

    Args:
        op: Operator definition dict.  ``op["op"]`` is always read;
            ``op["compliance"]`` is consulted when present (keys "ulp",
            "relative", "abs_error_lower_bound", "abs_error_normal_divisor").
        inputType: The operator's input dtype.
        argsDict: Test argument dict.  "dg_type" is always read; depending on
            the selected mode, "s", "ks"/"ksb", "max_abs_value" or "n" are
            read too.
        outputTensor: Result tensor; its ``dtype`` selects "data_type".
        errorName: ERROR_IF test name, or a falsy value for a normal test.

    Returns:
        A dict with "mode" (a ``gtu.ComplianceMode`` member name),
        "data_type" and any mode-specific "*_info" entry.  Returns None for
        error tests, when the output dtype is not supported by compliance,
        or when an unsupported input dtype is used with one of the
        dot-product style ops listed below.
    """
    # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
    UNSUPPORTED_NON_FP32_INPUT_OPS = (
        Op.MATMUL,
        Op.CONV2D,
        Op.FULLY_CONNECTED,
        Op.DEPTHWISE_CONV2D,
        Op.TRANSPOSE_CONV2D,
        Op.CONV3D,
    )
    if (
        errorName
        or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
        or (
            not gtu.dtypeIsSupportedByCompliance(inputType)
            and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
        )
    ):
        # No compliance for error tests or unsupported types currently
        return None

    # Create compliance meta data for expected output tensor
    compliance_tens = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }
    # NOTE: the order of this chain sets the precedence.  The data generator
    # type wins first (dot product, then special values), then any per-op
    # "ulp" or "relative" settings, then the op-specific modes.  Anything
    # left over is compared exactly.
    if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        compliance_tens["dot_product_info"] = {
            "s": argsDict["s"],
            # Prefer "ksb" over "ks" when the test arguments provide it
            "ks": int(argsDict["ksb"])
            if "ksb" in argsDict
            else int(argsDict["ks"]),
        }
    elif argsDict["dg_type"] == gtu.DataGenType.SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "compliance" in op and "ulp" in op["compliance"]:
        mode = gtu.ComplianceMode.ULP
        compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
    elif "compliance" in op and "relative" in op["compliance"]:
        mode = gtu.ComplianceMode.RELATIVE
        compliance_tens["relative_info"] = {
            "max": argsDict["max_abs_value"],
            "scale": op["compliance"]["relative"],
        }
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
        compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = gtu.ComplianceMode.ABS_ERROR
        # Optional per-op lower bound for the absolute error check
        if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
            compliance_tens["abs_error_info"] = {
                "lower_bound": op["compliance"]["abs_error_lower_bound"]
            }
    elif op["op"] in (Op.SIN, Op.COS):
        mode = gtu.ComplianceMode.ABS_ERROR
        # Optional per-op normal divisor for the absolute error check
        if "compliance" in op and "abs_error_normal_divisor" in op["compliance"]:
            compliance_tens["abs_error_info"] = {
                "normal_divisor": op["compliance"]["abs_error_normal_divisor"]
            }
    else:
        mode = gtu.ComplianceMode.EXACT
    # Store the mode by its enum member name
    compliance_tens["mode"] = gtu.ComplianceMode(mode).name

    return compliance_tens
405
406 # Build Op functions
407 # Create the output tensor (calling OutputShaper as needed)
408 # Do final tweaks to attributes (if necessary for errorIf)
409 # Add Op into graph
410 # Return resulting tensor information or BuildInfo
411
class BuildInfo:
    """Enhanced build information containing result tensor and associated compliance dict.

    Holds one or more result tensors together with an optional compliance
    dict for each.  A single tensor/dict pair is stored as one-element lists.
    """

    def __init__(self, resultTensor, complianceDict):
        if not isinstance(resultTensor, list):
            # Single result tensor: wrap it (and its compliance) in lists
            self.resultTensorList = [resultTensor]
            self.complianceDictList = (
                None if complianceDict is None else [complianceDict]
            )
            return

        # Multiple result tensors need a matching list of compliance dicts
        assert complianceDict is None or isinstance(complianceDict, list)
        self.resultTensorList = resultTensor
        self.complianceDictList = complianceDict

    def getComplianceInfo(self):
        """Return ``{"version": "0.1", "tensors": {name: dict}}``, or None.

        Tensors whose compliance entry is None are left out.  None is
        returned when no compliance data was supplied at all or every entry
        is None.
        """
        if self.complianceDictList is None:
            return None

        per_tensor = {
            tensor.name: entry
            for tensor, entry in zip(self.resultTensorList, self.complianceDictList)
            if entry is not None
        }
        if not per_tensor:
            return None
        return {
            "version": "0.1",
            "tensors": per_tensor,
        }
Eric Kunzee5e26762020-10-13 16:11:07 -0700445
def build_unary(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a single-input operator to the graph.

    Returns None when the ERROR_IF validators reject the configuration;
    otherwise returns a ``TosaTestGen.BuildInfo`` with the result tensor and
    its compliance meta data.  For NEGATE, ``qinfo[0]``/``qinfo[1]`` are used
    as the input/output zero points of the attribute.
    """
    assert len(inputs) == 1
    source = inputs[0]
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, source, error_name)

    assert not isinstance(op, int)

    # Ensure new output type has correct qinfo
    if (
        error_name == ErrorIf.WrongOutputType
        and out_tensor.dtype not in [DType.INT8, DType.UINT8]
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, source.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    in_names = [source.name]
    out_names = [out_tensor.name]
    placeholder_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=source.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    attr = None
    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = self.tensorComplianceMetaData(
        op, source.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700498
def build_binary_broadcast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a two-input broadcasting operator to the graph.

    ``validator_fcns`` now defaults to None, consistent with the other
    ``build_*`` functions; existing positional/keyword callers are
    unaffected.  ``qinfo`` is accepted for interface consistency but is not
    used.

    Returns None when the ERROR_IF validators reject the configuration;
    otherwise returns a ``TosaTestGen.BuildInfo`` with the result tensor and
    its compliance meta data.
    """
    assert len(inputs) == 2
    a, b = inputs
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700540
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input, non-broadcasting operator and return its result tensor.

    ``validator_fcns`` and ``error_name`` are accepted but not used by this
    builder, and no ERROR_IF validation is performed.
    """
    out_tensor = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op["op"], operand_names, [out_tensor.name])
    return out_tensor
545
def build_arithmetic_right_shift(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add an ARITHMETIC_RIGHT_SHIFT operator to the graph.

    ``args_dict["round"]`` is passed to ``ArithmeticRightShiftAttribute``
    (presumably a boolean rounding flag; confirm against the arg generator).

    Returns None when the ERROR_IF validators reject the configuration;
    otherwise returns a ``TosaTestGen.BuildInfo`` with the result tensor and
    its compliance meta data.
    """
    assert len(inputs) == 2
    a, b = inputs
    # Named round_flag rather than `round` to avoid shadowing the builtin
    round_flag = args_dict["round"]
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(round_flag)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800591
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a MUL operator to the graph.

    ``inputs`` holds the two operands followed by the shift value tensor.
    For non-float inputs the result dtype is forced to INT32.  For the
    WrongOutputType ERROR_IF test it is replaced by a random choice of
    INT8/INT16/INT48.

    Returns None when the ERROR_IF validators reject the configuration;
    otherwise returns a ``TosaTestGen.BuildInfo`` with the result tensor and
    its compliance meta data.
    """
    # Note that mul is binary operator but it has a shift value tensor
    assert len(inputs) == 3
    lhs, rhs, shift = inputs

    out_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    is_float_input = lhs.dtype in (DType.FP16, DType.BF16, DType.FP32)
    if not is_float_input:
        # Special for multiply: Force the result to INT32 for INT types
        out_tensor.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        out_tensor.setDtype(self.rng.choice([DType.INT8, DType.INT16, DType.INT48]))

    # Invalidate Input/Output list for error if checks.
    in_names = [lhs.name, rhs.name, shift.name]
    out_names = [out_tensor.name]
    placeholder_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700644
def build_table(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a TABLE operator using the lookup table in ``args_dict["table"]``.

    Returns None when the ERROR_IF validators reject the configuration;
    otherwise returns a ``TosaTestGen.BuildInfo`` with the result tensor and
    its compliance meta data.
    """
    assert len(inputs) == 1
    source = inputs[0]
    lookup = args_dict["table"]
    out_tensor = OutputShaper.tableOp(self.ser, self.rng, source, error_name)

    table_attr = ts.TosaSerializerAttribute()
    table_attr.TableAttribute(lookup)

    # Invalidate Input/Output list for error if checks.
    in_names = [source.name]
    out_names = [out_tensor.name]
    placeholder_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=source.shape,
        input_dtype=source.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, table_attr)

    compliance = self.tensorComplianceMetaData(
        op, source.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700687
def build_select(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SELECT operator test and add it to the serializer.

    Args:
        op: Operator definition dict; its "op" and "operands" entries are read.
        inputs: Exactly three tensors, unpacked as (condition, a, b).
        args_dict: Test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions handed to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case to generate, or None for a valid test.
        qinfo: Accepted for a uniform builder signature; not used here.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation step rejects the test.
    """
    assert len(inputs) == 3
    selector, lhs, rhs = inputs

    out = OutputShaper.selectOp(self.ser, self.rng, selector, lhs, rhs, error_name)

    # Operand counts come from the op definition; the name lists may then be
    # invalidated by the ERROR_IF helper for negative tests.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [selector.name, lhs.name, rhs.name], [out.name]
    )

    checks = dict(
        op=op,
        input1=selector,
        input2=lhs,
        input3=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    # Compliance is keyed on the dtype of the first data input (not the condition).
    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700735
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a binary comparison operator test and add it to the serializer.

    Args:
        op: Operator definition dict; its "op" and "operands" entries are read.
        inputs: Exactly two operand tensors.
        args_dict: Test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions handed to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case to generate, or None for a valid test.
        qinfo: Accepted for a uniform builder signature; not used here.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation step rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out = OutputShaper.binaryComparisonOp(self.ser, self.rng, lhs, rhs, error_name)

    # Operand counts come from the op definition; the name lists may then be
    # invalidated by the ERROR_IF helper for negative tests.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out.name]
    )

    checks = dict(
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700783
def build_argmax(
    self, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
):
    """Build an ARGMAX operator test and add it to the serializer.

    Args:
        op: Operator definition dict; its "op" and "operands" entries are read.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; "axis" is used for the output shape, the
            validators and the AxisAttribute.
        validator_fcns: ERROR_IF validator functions handed to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case to generate, or None for a valid test.
        qinfo: Accepted for a uniform builder signature; not used here.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation step rejects the test.
    """
    assert len(inputs) == 1
    [src] = inputs
    reduce_axis = args_dict["axis"]
    out = OutputShaper.argmaxOp(self.ser, self.rng, src, reduce_axis, error_name)

    # Operand counts come from the op definition; the name lists may then be
    # invalidated by the ERROR_IF helper for negative tests.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    checks = dict(
        op=op,
        axis=reduce_axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(reduce_axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700827
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator test and add it to the serializer.

    Args:
        op: Operator definition dict; its "op" and "operands" entries are read.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; reads "kernel", "stride", "pad" and, when
            present, "acc_type" (DType.UNKNOWN otherwise).
        validator_fcns: ERROR_IF validator functions handed to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case to generate, or None for a valid test.
        qinfo: Two zero-point values forwarded as qinfo[0]/qinfo[1] to
            PoolAttribute; None falls back to [0, 0].

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation step rejects the test.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    # max_pool carries no accumulator type, so "acc_type" may be missing.
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    kernel = args_dict["kernel"]
    stride = args_dict["stride"]
    pad = args_dict["pad"]

    out = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # For the wrong-input-type case, recompute zero points for the new
    # (non 8-bit) input type and the output type.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out.dtype),
        ]

    # Operand counts come from the op definition; the name lists may then be
    # invalidated by the ERROR_IF helper for negative tests.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out.name]
    )

    checks = dict(
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        accum_dtype=accum_dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    zero_points = [0, 0] if qinfo is None else qinfo

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )
    self.ser.addOperator(op["op"], in_names, out_names, pool_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
James Ward8b390432022-08-12 20:48:56 +0100902
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator test and add it to the serializer.

    Args:
        op: Operator definition dict; its "op" and "operands" entries are read.
        inputs: Exactly three tensors: input feature map, weights, bias.
        args_dict: Test arguments; reads "acc_type", "stride", "pad" (4 values)
            and "dilation".
        validator_fcns: ERROR_IF validator functions handed to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case to generate, or None for a valid test.
        qinfo: Two zero-point values forwarded as qinfo[0]/qinfo[1] to
            ConvAttribute; None falls back to [0, 0].

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation step rejects the test.
    """
    assert len(inputs) == 3
    # "weight" rather than "filter" to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points of 0
    # (as build_pool2d does) instead of failing on qinfo[0]. Done after the
    # validators so they still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700984
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator test and add it to the serializer.

    Args:
        op: Operator definition dict; its "op" and "operands" entries are read.
        inputs: Exactly three tensors: input feature map, weights, bias.
        args_dict: Test arguments; reads "acc_type", "stride", "pad" (6 values)
            and "dilation".
        validator_fcns: ERROR_IF validator functions handed to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case to generate, or None for a valid test.
        qinfo: Two zero-point values forwarded as qinfo[0]/qinfo[1] to
            ConvAttribute; None falls back to [0, 0].

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation step rejects the test.
    """
    assert len(inputs) == 3
    # "weight" rather than "filter" to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points of 0
    # (as build_pool2d does) instead of failing on qinfo[0]. Done after the
    # validators so they still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001066
def build_transpose_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator test and add it to the serializer.

    Args:
        op: Operator definition dict; its "op" and "operands" entries are read.
        inputs: Exactly three tensors: input feature map, weights, bias.
        args_dict: Test arguments; reads "acc_type", "stride", "pad"
            (4 values, used as out_pad) and "out_shape".
        validator_fcns: ERROR_IF validator functions handed to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case to generate, or None for a valid test.
        qinfo: Two zero-point values forwarded as qinfo[0]/qinfo[1] to
            TransposeConvAttribute; None falls back to [0, 0].

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation step rejects the test.
    """
    assert len(inputs) == 3
    # "weight" rather than "filter" to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points of 0
    # (as build_pool2d does) instead of failing on qinfo[0]. Done after the
    # validators so they still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001141
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator test and add it to the serializer.

    Args:
        op: Operator definition dict; its "op" and "operands" entries are read.
        inputs: Exactly three tensors: input feature map, weights, bias.
        args_dict: Test arguments; reads "acc_type", "stride", "pad" and
            "dilation".
        validator_fcns: ERROR_IF validator functions handed to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case to generate, or None for a valid test.
        qinfo: Two zero-point values forwarded as qinfo[0]/qinfo[1] to
            ConvAttribute; None falls back to [0, 0].

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation step rejects the test.
    """
    assert len(inputs) == 3
    # "weight" rather than "filter" to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points of 0
    # (as build_pool2d does) instead of failing on qinfo[0]. Done after the
    # validators so they still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001222
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator test and add it to the serializer.

    Args:
        op: Operator definition dict; its "op" and "operands" entries are read.
        inputs: Exactly three tensors: input, weights, bias.
        args_dict: Test arguments; reads "acc_type".
        validator_fcns: ERROR_IF validator functions handed to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case to generate, or None for a valid test.
        qinfo: Two zero-point values forwarded as qinfo[0]/qinfo[1] to
            FullyConnectedAttribute; None falls back to [0, 0].

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation step rejects the test.
    """
    assert len(inputs) == 3
    # "weight" rather than "filter" to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points of 0
    # (as build_pool2d does) instead of failing on qinfo[0]. Done after the
    # validators so they still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001278
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator test and add it to the serializer.

    Args:
        op: Operator definition dict; its "op" and "operands" entries are read.
        inputs: Exactly two operand tensors (a, b).
        args_dict: Test arguments; reads "acc_type".
        validator_fcns: ERROR_IF validator functions handed to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case to generate, or None for a valid test.
        qinfo: Two zero-point values forwarded as qinfo[0]/qinfo[1] to
            MatMulAttribute; None falls back to [0, 0].

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation step rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points of 0
    # (as build_pool2d does) instead of failing on qinfo[0]. Done after the
    # validators so they still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001328
def build_reduce(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Build a REDUCE_* operator test and add it to the serializer.

    Args:
        op: Operator definition dict; its "op" and "operands" entries are read.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; "axis" selects the reduced dimension. For a
            valid REDUCE_PRODUCT test, "n" is written into this dict (the
            input's extent along that axis) before the compliance data is built.
        validator_fcns: ERROR_IF validator functions handed to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case to generate, or None for a valid test.
        qinfo: Accepted for a uniform builder signature; not used here.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation step rejects the test.
    """
    assert len(inputs) == 1
    [src] = inputs
    reduce_axis = args_dict["axis"]
    out = OutputShaper.reduceOp(self.ser, self.rng, src, reduce_axis, error_name)

    # Operand counts come from the op definition; the name lists may then be
    # invalidated by the ERROR_IF helper for negative tests.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    checks = dict(
        op=op,
        axis=reduce_axis,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(reduce_axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    is_valid_product = error_name is None and op["op"] == Op.REDUCE_PRODUCT
    if is_valid_product:
        # Compliance checking needs the number of multiplied elements.
        args_dict["n"] = src.shape[reduce_axis]

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001377
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CLAMP operator with randomly drawn min/max bounds.

    Two random values of the input dtype are drawn. Normally the smaller one
    becomes the minimum and the larger one the maximum. For the
    ErrorIf.MaxSmallerMin case they are redrawn until distinct and then
    swapped, so that max < min.

    Args:
        op: Operator table entry; its "op" and "operands" keys are read.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments, forwarded to compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs rejects the test case.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    def _draw_pair():
        # Left-to-right evaluation keeps the RNG draw order unchanged.
        return [
            self.getRandNumberDType(src.dtype),
            self.getRandNumberDType(src.dtype),
        ]

    bounds = _draw_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # Bounds must differ, otherwise swapping them would not give max < min.
        while bounds[0] == bounds[1]:
            bounds = _draw_pair()
        min_val, max_val = max(bounds), min(bounds)
    else:
        min_val, max_val = min(bounds), max(bounds)

    # For ERROR_IF tests the name lists may be deliberately invalidated.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    ):
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    if src.dtype in (DType.BF16, DType.FP16, DType.FP32):
        if src.dtype == DType.FP16:
            # Non-tensor fp16 values are passed to the reference model as fp32.
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        int_bounds, fp_bounds = (0, 0), (min_val, max_val)
    elif src.dtype in (DType.INT8, DType.INT16):
        int_bounds, fp_bounds = (min_val, max_val), (0, 0)
    else:
        # Other input types get all-zero bounds to avoid an internal error
        # while serializing the attribute.
        int_bounds = fp_bounds = (0, 0)
    clamp_attr.ClampAttribute(self.ser.builder, *int_bounds, *fp_bounds)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001446
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a LEAKY_RELU operator (legacy single-tensor builder signature).

    The attribute value is a random FP32 number (presumably alpha; confirm
    against LeakyReluAttribute). This builder runs no ERROR_IF validation,
    and it returns the bare result tensor rather than a BuildInfo.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    relu_attr = ts.TosaSerializerAttribute()
    relu_attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32))
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], relu_attr)
    return out_tensor
1455
# TODO: PRELU needs an additional type/input; only one input is wired up here.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PRELU operator from a single input tensor (legacy builder).

    No attribute is attached and no ERROR_IF validation is run. The bare
    result tensor is returned rather than a BuildInfo.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1462
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a single-input activation operator that takes no attribute.

    Args:
        op: Operator table entry; its "op" and "operands" keys are read.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments, forwarded to compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs rejects the test case.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    p_count, c_count = op["operands"]
    # ERROR_IF tests may corrupt the name lists on purpose.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001503
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONCAT or CONCAT_SHAPE operator joining all ``inputs``.

    CONCAT reads its axis from args_dict["axis"] and serializes it as an
    AxisAttribute. CONCAT_SHAPE always uses axis 0 and carries no attribute.

    Args:
        op: Operator table entry; its "op" and "operands" keys are read.
        inputs: Tensors to concatenate. The first one supplies the input
            shape/dtype used for validation and compliance.
        args_dict: Test arguments ("axis" is read for CONCAT).
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs rejects the test case.
    """
    axis = 0 if op["op"] == Op.CONCAT_SHAPE else args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # isinstance rather than a direct type() comparison (PEP 8 / E721).
        assert isinstance(axis, int)

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if op["op"] == Op.CONCAT:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    else:
        assert op["op"] == Op.CONCAT_SHAPE
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001562
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator from an input tensor plus a padding tensor.

    Args:
        op: Operator table entry; its "op" and "operands" keys are read.
        inputs: [input tensor, padding tensor].
        args_dict: Test arguments; "pad", "pad_const_int" and "pad_const_fp"
            are read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Forwarded to ERROR_IF validation.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs rejects the test case.
    """
    assert len(inputs) == 2
    src = inputs[0]
    padding_tensor = inputs[1]
    pad_amounts = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out_tensor = OutputShaper.padOp(
        self.ser, self.rng, src, pad_amounts, error_name
    )

    # The attribute's padding list is left empty so that inputs[1] is the
    # padding source actually used.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, [], const_int, const_fp)

    # For ERROR_IF tests the name lists may be deliberately invalidated.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, padding_tensor.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        pad=pad_amounts,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
        input1=src,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001620
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator that queries one axis of a single input tensor.

    Args:
        op: Operator table entry; its "op" and "operands" keys are read.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; "axis" is read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, None), with no compliance
        metadata, or None when evValidateErrorIfs rejects the test case.
    """
    assert len(inputs) == 1
    src = inputs[0]
    dim_axis = args_dict["axis"]
    out_tensor = OutputShaper.dimOp(self.ser, self.rng, src, dim_axis, error_name)

    # For ERROR_IF tests the name lists may be deliberately invalidated.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        axis=dim_axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(dim_axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001666
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a RESHAPE operator from an input tensor and a shape tensor.

    The output shape is computed from args_dict["new_shape"]. The second
    input (presumably the shape tensor holding the same values; confirm)
    is wired in as an operand, and no attribute is attached.

    Args:
        op: Operator table entry; its "op" and "operands" keys are read.
        inputs: [input tensor, shape tensor].
        args_dict: Test arguments; "new_shape" is read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs rejects the test case.
    """
    assert len(inputs) == 2
    src = inputs[0]
    shape_tensor = inputs[1]
    target_shape = args_dict["new_shape"]
    out_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, src, target_shape, error_name
    )

    # For ERROR_IF tests the name lists may be deliberately invalidated.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, shape_tensor.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001710
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REVERSE operator along one axis of a single input tensor.

    Args:
        op: Operator table entry; its "op" and "operands" keys are read.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; "axis" is read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, None), with no compliance
        metadata, or None when evValidateErrorIfs rejects the test case.
    """
    assert len(inputs) == 1
    src = inputs[0]
    rev_axis = args_dict["axis"]
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    # For ERROR_IF tests the name lists may be deliberately invalidated.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=rev_axis,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(rev_axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001750
def build_transpose(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TRANSPOSE operator from one input tensor and a permutation.

    Args:
        op: Operator table entry; its "op" and "operands" keys are read.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; "perms" is read and serialized as a
            TransposeAttribute.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs rejects the test case.
    """
    assert len(inputs) == 1
    src = inputs[0]
    permutation = args_dict["perms"]

    out_tensor = OutputShaper.transposeOp(
        self.ser, self.rng, src, permutation, error_name
    )

    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(permutation)

    # For ERROR_IF tests the name lists may be deliberately invalidated.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        perms=permutation,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
        input1=src,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, transpose_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001799
def build_slice(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SLICE operator from an input tensor plus start/size tensors.

    The output shape and ERROR_IF checks use the constant start/size values
    from args_dict. The start/size tensors are wired in as operands, and no
    attribute is attached.

    Args:
        op: Operator table entry; its "op" and "operands" keys are read.
        inputs: [input tensor, start tensor, size tensor].
        args_dict: Test arguments; "start" and "size" are read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs rejects the test case.
    """
    assert len(inputs) == 3
    src, start_tensor, size_tensor = inputs
    start_vals = args_dict["start"]
    size_vals = args_dict["size"]

    out_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, src, start_vals, size_vals, error_name
    )

    # For ERROR_IF tests the name lists may be deliberately invalidated.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [src.name, start_tensor.name, size_tensor.name],
        [out_tensor.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        start=start_vals,
        size=size_vals,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
        input1=src,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001847
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TILE operator from an input tensor plus a multiples tensor.

    The output shape is computed from args_dict["multiples"]. The second
    input is wired in as an operand, and no attribute is attached.

    Args:
        op: Operator table entry; its "op" and "operands" keys are read.
        inputs: [input tensor, multiples tensor].
        args_dict: Test arguments; "multiples" is read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs rejects the test case.
    """
    assert len(inputs) == 2
    src = inputs[0]
    multiples_tensor = inputs[1]
    repeat_counts = args_dict["multiples"]
    out_tensor = OutputShaper.tileOp(
        self.ser, self.rng, src, repeat_counts, error_name
    )

    # For ERROR_IF tests the name lists may be deliberately invalidated.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, multiples_tensor.name], [out_tensor.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
        input1=src,
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001892
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a GATHER operator from a values tensor and an indices tensor.

    Args:
        op: Operator table entry; its "op" and "operands" keys are read.
        inputs: [values tensor, indices tensor].
        args_dict: Test arguments, forwarded to compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs rejects the test case. The compliance dtype
        is taken from the values tensor.
    """
    assert len(inputs) == 2
    values, indices = inputs

    out_tensor = OutputShaper.gatherOp(
        self.ser, self.rng, values, indices, error_name
    )

    # For ERROR_IF tests the name lists may be deliberately invalidated.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001935
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SCATTER operator from (values_in, indices, input) tensors.

    Args:
        op: Operator table entry; its "op" and "operands" keys are read.
        inputs: [values_in tensor, indices tensor, input tensor].
        args_dict: Test arguments, forwarded to compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs rejects the test case. The compliance dtype
        is taken from values_in.
    """
    assert len(inputs) == 3
    # Named input_tensor (not ``input``) to avoid shadowing the builtin.
    values_in, indices, input_tensor = inputs
    result_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input_tensor, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [values_in.name, indices.name, input_tensor.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001977
def build_resize(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESIZE operator test.

    Args:
        op: Operator table entry; ``op["op"]`` is the operator added to the
            graph and ``op["operands"]`` is a pair of operand counts that are
            summed into ``num_operands`` for validation.
        inputs: Four tensors - the data input followed by the scale, offset
            and border tensors. All four names become operator inputs.
        args_dict: Supplies ``mode``, ``scale``, ``offset``, ``border`` and
            ``output_dtype``, used to compute the output tensor and to run the
            ERROR_IF validators.
        validator_fcns: ERROR_IF validator functions (defaults to ``None``,
            consistent with the other builders).
        error_name: The ERROR_IF condition being generated, or ``None``.
        qinfo: Not used.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor and compliance metadata), or
        ``None`` if the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 4
    input_tensor, scale_input, offset_input, border_input = inputs

    mode = args_dict["mode"]
    scale = args_dict["scale"]
    offset = args_dict["offset"]
    border = args_dict["border"]
    output_dtype = args_dict["output_dtype"]

    result_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input_tensor,
        mode,
        scale,
        offset,
        border,
        input_tensor.dtype,
        output_dtype,
        error_name,
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [
        input_tensor.name,
        scale_input.name,
        offset_input.name,
        border_input.name,
    ]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_tensor.dtype,
        output_dtype=output_dtype,
        input_shape=input_tensor.shape,
        output_shape=result_tensor.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    # Scale/offset/border are passed as operator input tensors, so the
    # attribute is written with empty lists for them and only carries the mode
    attr.ResizeAttribute([], [], [], mode)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, input_tensor.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002056
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Add ``op`` with inputs ``val``/``val2`` and two unary-shaped outputs.

    Each output tensor is shaped by ``OutputShaper.unaryOp`` from the matching
    input. Only the first output tensor is returned. ``validator_fcns`` is
    not used.

    NOTE(review): unlike the other builders in this file, ``op`` is passed
    straight to ``addOperator`` rather than ``op["op"]`` - confirm what
    callers pass here.
    """
    outputs = [
        OutputShaper.unaryOp(self.ser, self.rng, source, error_name)
        for source in (val, val2)
    ]
    self.ser.addOperator(
        op, [val.name, val2.name], [out.name for out in outputs]
    )
    return outputs[0]
2064
def build_const(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONST test by exposing the single supplied tensor as an output.

    No operator is added; the tensor from ``inputs`` is registered as a graph
    output and also serves as the result for compliance metadata.
    ``validator_fcns`` and ``qinfo`` are not used.

    Returns ``TosaTestGen.BuildInfo`` for that tensor.
    """
    assert len(inputs) == 1
    const_tens = inputs[0]
    self.ser.addOutputTensor(const_tens)

    return TosaTestGen.BuildInfo(
        const_tens,
        self.tensorComplianceMetaData(
            op, const_tens.dtype, args_dict, const_tens, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002077
2078 # Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CAST test converting the single input to ``args_dict["out_type"]``.

    The output tensor is created by ``OutputShaper.typeConversionOp``.
    ``qinfo`` is not used.

    Returns ``TosaTestGen.BuildInfo`` holding the result tensor and its
    compliance metadata, or ``None`` when the ERROR_IF validators reject the
    generated test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    target_dtype = args_dict["out_type"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, src, target_dtype, error_name
    )

    placeholder_count, const_count = op["operands"]
    # The ERROR_IF helper may rewrite the operand name lists for error tests
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [result_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, src.dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002122
def build_rescale(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESCALE operator test.

    Args:
        op: Operator table entry; ``op["op"]`` is the operator added to the
            graph and ``op["operands"]`` is a pair of operand counts summed
            into ``num_operands`` for validation.
        inputs: Three tensors - the data input followed by the multiplier and
            shift tensors. All three names become operator inputs.
        args_dict: Supplies ``output_dtype``, ``scale`` (the scale32 flag),
            ``double_round``, ``per_channel``, ``shift`` and ``multiplier``.
        validator_fcns: ERROR_IF validator functions.
        error_name: The ERROR_IF condition being generated, or ``None``.
        qinfo: Ignored; the (input_zp, output_zp) pair is generated here.

    Zero points are picked randomly according to the input/output dtypes, or
    forced non-zero for the zero-point ERROR_IF cases. For valid scale32
    tests the stored input data is clamped so that ``value - input_zp`` lies
    in ``[-(1 << (shift-1)), (1 << (shift-1)) - 1]``, and the placeholder
    file is rewritten if any value changed.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor and compliance metadata), or
        ``None`` if the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 3
    val = inputs[0]
    multiplier_val = inputs[1]
    shift_val = inputs[2]
    out_dtype = args_dict["output_dtype"]
    scale32 = args_dict["scale"]
    double_round = args_dict["double_round"]
    per_channel = args_dict["per_channel"]
    shift_arr = args_dict["shift"]
    multiplier_arr = args_dict["multiplier"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # Number of channels that carry their own shift value
    nc = val.shape[-1] if per_channel else 1

    input_unsigned = False
    output_unsigned = False

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        output_unsigned = True
    else:
        output_zp = 0

    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
        max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

    logger.debug(
        "build_rescale: multiplier=%s shift=%s inzp=%s outzp=%s",
        multiplier_arr,
        shift_arr,
        input_zp,
        output_zp,
    )
    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1)))
        assert val.placeholderFilename
        values_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(values_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check every adjusted value fits the stored dtype before casting back
        # (compare the real extremes; .all() would only yield a single bool)
        dtype_info = np.iinfo(values.dtype)
        assert val_adj.size == 0 or (
            val_adj.min() >= dtype_info.min and val_adj.max() <= dtype_info.max
        )

        # Cast back to the dtype the input values were stored with
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(values_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name, multiplier_val.name, shift_val.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    # Multiplier and shift are passed as input tensors, so the attribute's
    # multiplier/shift lists are written empty
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        [],
        [],
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002290
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor for a COND_IF test.

    Normally this is a rank-0 BOOL constant holding ``cond``. For
    ``ErrorIf.CondIfCondNotMatchingBool`` a wrong dtype is chosen via
    ``gtu.get_wrong_output_type``; for ``ErrorIf.CondIfCondShapeNotSizeOne``
    the shape is randomly ``[2]`` or ``[1, 2]``.

    Returns the tensor created by ``self.ser.addConst``.
    """
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2307
def build_cond_if_const(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose then/else blocks each output an INT32 constant.

    Of the two supplied tensors only the shape of the first (the "then"
    tensor) is used; it becomes the output shape. ``args_dict["condition"]``
    is the constant condition value. For the CondIf output-list ERROR_IF
    cases, the corresponding block outputs a constant whose every dimension
    has been changed (while staying positive).

    Returns ``TosaTestGen.BuildInfo`` for the COND_IF result tensor, or
    ``None`` if the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 2
    then_tens, _ = inputs

    cond = args_dict["condition"]
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape
    dtype = DType.INT32

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"

    # Pick the block (if any) that must output a wrongly-shaped constant
    if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
        bad_block = then_block
    elif error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
        bad_block = else_block
    else:
        bad_block = None

    if bad_block is not None:
        # Shrink dims above 3, grow the rest, so every dim stays positive
        incorrect_shape = [
            dim
            + (
                self.rng.choice([-3, -2, 2, 3])
                if dim > 3
                else self.rng.choice([1, 2, 4])
            )
            for dim in then_tens.shape
        ]
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The COND_IF result has the shape of the (correct) block outputs
    result_tensor = self.ser.addOutput(out_shape, dtype)

    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)

    for block, block_arr in ((then_block, then_arr), (else_block, else_arr)):
        self.ser.addBasicBlock(block)
        if block == bad_block:
            block_out = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
        else:
            block_out = self.ser.addConst(out_shape, dtype, block_arr)
        self.ser.addOutputTensor(block_out)

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not passed:
        return None

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002392
def build_cond_if_binary(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose then/else blocks each apply a binary op to a, b.

    Args:
        op: Operator table entry; ``op["op"]`` is the COND_IF operator.
        inputs: Two tensors ``a`` and ``b`` fed to COND_IF and to both blocks.
        args_dict: ``args_dict["condition"]`` is the constant condition value.
        validator_fcns: ERROR_IF validator functions.
        error_name: The ERROR_IF condition being generated, or ``None``.
        qinfo: Not used.

    The then/else blocks use ``TOSA_OP_LIST["add"]``/``["sub"]`` for
    FP32/BF16/FP16/INT32 inputs and ``["logical_right_shift"]``/
    ``["logical_left_shift"]`` for INT8/INT16 inputs. Compliance metadata is
    computed for the op of the branch selected by ``condition``. For the
    CondIf input/output-list ERROR_IF cases, one block gets an input or output
    whose every dimension differs from ``a`` (while staying positive).

    Returns:
        ``TosaTestGen.BuildInfo`` for the COND_IF result tensor, or ``None``
        if the ERROR_IF validators reject the test.

    Raises:
        AssertionError: if ``a.dtype`` has no then/else op mapping.
    """
    assert len(inputs) == 2
    a, b = inputs

    cond = args_dict["condition"]

    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tensor = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # Change every dimension but keep it positive (same rule as
        # build_cond_if_const): only dims above 3 may shrink
        incorrect_shape = [
            dim
            + (
                self.rng.choice([-3, -2, 2, 3])
                if dim > 3
                else self.rng.choice([1, 2, 4])
            )
            for dim in a.shape
        ]
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"]
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = (
            self.TOSA_OP_LIST["logical_right_shift"],
            self.TOSA_OP_LIST["logical_left_shift"],
        )
    else:
        # Explicit raise so the check survives `python -O`
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # Determine the element-wise binary operation that compliance will need to
    # check the results of
    compliance_op = then_op if cond else else_op

    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        compliance_op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2496
2497 def build_while_loop(
2498 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
2499 ):
2500 assert len(inputs) == 1
2501 a = inputs[0]
2502 iter_val = args_dict["iterations"]
2503
Kevin Cheng550ccc52021-03-03 11:21:43 -08002504 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07002505
Kevin Cheng550ccc52021-03-03 11:21:43 -08002506 cond_block = "COND_BLOCK"
2507 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002508
2509 attr = ts.TosaSerializerAttribute()
2510 attr.WhileLoopAttribute(cond_block, body_block)
2511
2512 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08002513 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002514 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08002515 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07002516
2517 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002518 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2519 a_out = self.ser.addIntermediate(a.shape, a.dtype)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002520 if error_name == ErrorIf.InputListOutputListMismatch:
2521 incorrect_acc = deepcopy(acc)
2522 for i in range(len(incorrect_acc.shape)):
2523 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2524 acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
2525 else:
2526 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002527
2528 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08002529 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002530 op["op"],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002531 [iter.name, a.name, acc.name],
2532 [iter_out.name, a_out.name, acc_out.name],
2533 attr,
2534 )
Kevin Chengb227ae52021-09-02 13:43:17 -07002535 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07002536
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002537 if error_name in [
2538 ErrorIf.InputListCondGraphMismatch,
2539 ErrorIf.InputListBodyGraphInputMismatch,
2540 ErrorIf.InputListBodyGraphOutputMismatch,
2541 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01002542 incorrect_iter = deepcopy(iter)
2543 for i in range(len(incorrect_iter.shape)):
2544 incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
2545 if len(incorrect_iter.shape) == 0:
2546 incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
2547
2548 incorrect_acc = deepcopy(acc)
2549 for i in range(len(incorrect_acc.shape)):
2550 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2551
Eric Kunzee5e26762020-10-13 16:11:07 -07002552 # COND block (input: iter, output: cond_tens )
Jerry Ge9e94af82022-10-27 09:57:00 -07002553 self.ser.addBasicBlock(cond_block)
2554
Matthew Haddon630c17c2021-10-14 15:05:41 +01002555 if error_name == ErrorIf.InputListCondGraphMismatch:
2556 self.ser.addInputTensor(incorrect_iter)
2557 self.ser.addInputTensor(a)
2558 self.ser.addInputTensor(incorrect_acc)
2559 else:
2560 self.ser.addInputTensor(iter)
2561 self.ser.addInputTensor(a)
2562 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002563 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002564
2565 if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002566 cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002567 else:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002568 cond_type = DType.BOOL
2569 if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
2570 choice = self.rng.choice([1, 2])
2571 if choice == 1:
2572 cond_shape = [3]
2573 else:
2574 cond_shape = [1, 2]
2575 else:
2576 cond_shape = []
2577 cond_tens = self.ser.addOutput(cond_shape, cond_type)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002578
Kevin Cheng550ccc52021-03-03 11:21:43 -08002579 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002580
2581 # BODY block (input: a, acc, iter, output: a, acc, iter)
2582 # Note that local intermediate tensors need to be declared here for the outputs
Jerry Ge9e94af82022-10-27 09:57:00 -07002583 self.ser.addBasicBlock(body_block)
2584
Matthew Haddon630c17c2021-10-14 15:05:41 +01002585 if error_name == ErrorIf.InputListBodyGraphInputMismatch:
2586 self.ser.addInputTensor(incorrect_iter)
2587 self.ser.addInputTensor(a)
2588 self.ser.addInputTensor(incorrect_acc)
2589 else:
2590 self.ser.addInputTensor(iter)
2591 self.ser.addInputTensor(a)
2592 self.ser.addInputTensor(acc)
2593
Kevin Cheng550ccc52021-03-03 11:21:43 -08002594 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002595
2596 if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002597 iter_body_out = self.ser.addIntermediate(
2598 incorrect_iter.shape, incorrect_iter.dtype
2599 )
2600 acc_body_out = self.ser.addIntermediate(
2601 incorrect_acc.shape, incorrect_acc.dtype
2602 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01002603 else:
2604 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2605 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
2606
Eric Kunzee5e26762020-10-13 16:11:07 -07002607 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2608 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2609 self.ser.addOutputTensor(iter_body_out)
2610 self.ser.addOutputTensor(a)
2611 self.ser.addOutputTensor(acc_body_out)
2612
Les Bell729b0352021-11-24 10:28:21 +00002613 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01002614 self.ser,
2615 validator_fcns,
2616 error_name,
2617 op=op,
Jerry Ge9e94af82022-10-27 09:57:00 -07002618 basicBlocks=self.ser.currRegion.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00002619 ):
2620 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01002621
Jeremy Johnson587cc842024-02-08 11:45:44 +00002622 compliance = self.tensorComplianceMetaData(
2623 op, a.dtype, args_dict, acc_out, error_name
2624 )
2625
2626 return TosaTestGen.BuildInfo(acc_out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002627
def build_fft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add an FFT2D operator to the current serializer.

    Args:
        op: TOSA_OP_LIST entry for the operator (its "op" and "operands"
            fields are read).
        inputs: exactly two input tensors. Both names are wired into the
            operator; the first tensor supplies the input shape/dtype used
            for ERROR_IF validation and compliance meta data.
        args_dict: argument dictionary; its "inverse" value is forwarded as
            the FFT attribute's inverse flag.
        validator_fcns: ERROR_IF validator functions to run, or None.
        error_name: ERROR_IF under test, or None for a positive test.
        qinfo: not used by this builder; accepted so every build function
            shares one calling convention.

    Returns:
        TosaTestGen.BuildInfo with the result tensors and one compliance
        entry per result, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    firstInput, secondInput = inputs
    isInverse = args_dict["inverse"]

    resultTensors = OutputShaper.fft2dOp(
        self.ser, self.rng, firstInput, secondInput, error_name
    )

    placeholderCount, constCount = op["operands"]
    operandCount = placeholderCount + constCount

    resultShapes = [tens.shape for tens in resultTensors]
    resultDtypes = [tens.dtype for tens in resultTensors]

    # ERROR_IF tests may deliberately corrupt the input/output name lists.
    inputNames, outputNames = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [firstInput.name, secondInput.name],
        [tens.name for tens in resultTensors],
    )

    isValid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=isInverse,
        input1=firstInput,
        input2=secondInput,
        input_shape=firstInput.shape,
        input_dtype=firstInput.dtype,
        output_shape=resultShapes,
        output_dtype=resultDtypes,
        result_tensors=resultTensors,
        input_list=inputNames,
        output_list=outputNames,
        num_operands=operandCount,
    )
    if not isValid:
        return None

    attr = ts.TosaSerializerAttribute()
    # TODO - Test local_bound; for now the local_bound attribute is always False.
    attr.FFTAttribute(isInverse, False)
    self.ser.addOperator(op["op"], inputNames, outputNames, attr)

    compliance = [
        self.tensorComplianceMetaData(
            op, firstInput.dtype, args_dict, tens, error_name
        )
        for tens in resultTensors
    ]
    return TosaTestGen.BuildInfo(resultTensors, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002691
def build_rfft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add an RFFT2D operator to the current serializer.

    Args:
        op: TOSA_OP_LIST entry for the operator (its "op" and "operands"
            fields are read).
        inputs: exactly one input tensor.
        args_dict: argument dictionary, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions to run, or None.
        error_name: ERROR_IF under test, or None for a positive test.
        qinfo: not used by this builder; accepted so every build function
            shares one calling convention.

    Returns:
        TosaTestGen.BuildInfo with the result tensors and one compliance
        entry per result, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    (inputTensor,) = inputs

    resultTensors = OutputShaper.rfft2dOp(
        self.ser, self.rng, inputTensor, error_name
    )

    placeholderCount, constCount = op["operands"]
    operandCount = placeholderCount + constCount

    resultShapes = [tens.shape for tens in resultTensors]
    resultDtypes = [tens.dtype for tens in resultTensors]

    # ERROR_IF tests may deliberately corrupt the input/output name lists.
    inputNames, outputNames = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [inputTensor.name],
        [tens.name for tens in resultTensors],
    )

    isValid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=inputTensor.shape,
        input_dtype=inputTensor.dtype,
        output_shape=resultShapes,
        output_dtype=resultDtypes,
        result_tensors=resultTensors,
        input_list=inputNames,
        output_list=outputNames,
        num_operands=operandCount,
    )
    if not isValid:
        return None

    attr = ts.TosaSerializerAttribute()
    # TODO - Test local_bound; for now the local_bound attribute is always False.
    attr.RFFTAttribute(False)
    self.ser.addOperator(op["op"], inputNames, outputNames, attr)

    compliance = [
        self.tensorComplianceMetaData(
            op, inputTensor.dtype, args_dict, tens, error_name
        )
        for tens in resultTensors
    ]
    return TosaTestGen.BuildInfo(resultTensors, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002748
def build_shape_op(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a two-input shape operator to the current serializer.

    The result tensor is created by OutputShaper.addShapeOp; the operator
    itself is whatever op["op"] names.

    Args:
        op: TOSA_OP_LIST entry for the operator.
        inputs: exactly two input tensors.
        args_dict: argument dictionary, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions to run, or None.
        error_name: ERROR_IF under test, or None for a positive test.
        qinfo: not used by this builder.

    Returns:
        TosaTestGen.BuildInfo for the single result tensor, or None when
        ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    resultTensor = OutputShaper.addShapeOp(self.ser, self.rng, lhs, rhs, error_name)

    placeholderCount, constCount = op["operands"]
    operandCount = placeholderCount + constCount

    # Invalidate Input/Output list for error if checks.
    inputNames, outputNames = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [resultTensor.name]
    )

    isValid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=resultTensor.shape,
        output_dtype=resultTensor.dtype,
        result_tensors=[resultTensor],
        input_list=inputNames,
        output_list=outputNames,
        num_operands=operandCount,
    )
    if not isValid:
        return None

    self.ser.addOperator(op["op"], inputNames, outputNames)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, resultTensor, error_name
    )
    return TosaTestGen.BuildInfo(resultTensor, compliance)
2794
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out which shapes, ranks and dtypes to generate tests for.

    Combines the caller's requested filters with what the operator entry
    allows (its "rank" range and "types" list).

    Args:
        op: TOSA_OP_LIST entry for the operator.
        shapeFilter: requested shapes; a falsy value means "any shape".
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None.
        testType: "positive" or "negative".
        validator: ERROR_IF validator for negative tests. It is called with
            check=False and its "param_reqs" may pin rank/dtype/shape.

    Returns:
        dict with keys "shapeFilter", "rankFilter" and "dtypeFilter". Returns
        None for a negative test without a validator, or when testType is
        neither "positive" nor "negative".
    """
    # Keep default test sizes reasonably small: ranks 1-4 inclusive.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    rmin, rmax = op["rank"]
    opRankRange = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Only keep requested ranks the operator supports.
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is not None:
        # Explicit shapes decide the rank; allow everything the op supports.
        cleanRankFilter = opRankRange
    elif len(opRankRange) <= len(default_test_rank_range):
        cleanRankFilter = opRankRange
    else:
        cleanRankFilter = default_test_rank_range

    dtypes = op["types"]
    if dtypeFilter is None:
        cleanDtypeFilter = dtypes
    else:
        # A dtype entry may itself be a list (e.g. the TYPE_CONV triples);
        # such entries are matched on their first element.
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }
    if testType != "negative" or validator is None:
        return None

    # The ERROR_IF validator may pin rank, dtype or shape; otherwise fall back
    # to the cleaned filters, with at most two shapes to keep test counts small.
    param_reqs = validator(check=False, op=op)["param_reqs"]
    reqRank = param_reqs["rank"]
    reqDtype = param_reqs["dtype"]
    reqShape = param_reqs["shape"]
    return {
        "shapeFilter": shapeFilter[:2] if reqShape is None else reqShape,
        "rankFilter": cleanRankFilter if reqRank is None else reqRank,
        "dtypeFilter": cleanDtypeFilter if reqDtype is None else reqDtype,
    }
2875
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the test descriptors for one operator.

    The random generator is re-created from ``self.random_seed`` on every
    call, so the random draws made here do not depend on earlier calls.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter: optional list of shapes to restrict generation to.
            ``None`` (the default) or an empty list means no restriction;
            ``create_filter_lists`` turns it into ``[None]``.
        rankFilter: optional iterable of ranks to generate.
        dtypeFilter: optional iterable of dtypes to generate.
        testType: ``"positive"`` for normal tests, ``"negative"`` for
            ERROR_IF tests.

    Returns:
        list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Initialize a new random number generator from the fixed seed
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    # Each entry is (opName, testStr, dtype, error_name, shapeList, args)
    testList = []
    for validator in error_if_validators:
        if validator is None:
            error_name = None
        else:
            error_name = validator(check=False, op=op)["error_name"]

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            # Nothing to generate for this validator. Skip it instead of
            # discarding tests already collected for earlier validators.
            continue
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]
        logger.debug(
            f"genOpTestList: Error={error_name}, Filters S={cleanShapeFilter}, R={cleanRankFilter}, T={cleanDtypeFilter}"
        )

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # (argStr, args) pairs: the test-name suffix and the
                    # arguments later handed to the build function.
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    # Remove positive tests which are expected to fail but don't correlate
    # to an ERROR_IF statement.
    if testType == "positive" and "invalid_test_validators" in op:
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2984
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
):
    """Create, build and serialize one test produced by genOpTestList.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        testStr: test name, used for the serializer and log messages.
        dtype_or_dtypeList: a single dtype applied to every operand, or an
            explicit per-operand list.
        error_name: ERROR_IF name for negative tests, otherwise None.
        shapeList: one shape per operand.
        argsDict: argument dictionary; must be a dict containing "dg_type".

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
        AssertionError: if the shape/dtype counts do not match the operand
            count, or ``argsDict`` is not a dict with a "dg_type" entry.
    """
    if opName not in self.TOSA_OP_LIST:
        raise Exception("Cannot find op with name {}".format(opName))
    op = self.TOSA_OP_LIST[opName]

    logger.info(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholderCount, constCount = op["operands"]
    num_operands = placeholderCount + constCount
    # CONCAT/CONCAT_SHAPE are sized by shapeList rather than by the fixed
    # operand count.
    sizedByShapes = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        dtypeCount = len(shapeList) if sizedByShapes else num_operands
        dtypeList = [dtype_or_dtypeList] * dtypeCount

    if not sizedByShapes:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Check we are using the new interface with an argsDict dictionary
    assert isinstance(
        argsDict, dict
    ), f"{opName} is not using new tvg/build_fcn interface"
    assert "dg_type" in argsDict

    # Extra meta data for the desc.json
    tensMeta = {}
    tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
    if tvgInfo.dataGenDict:
        tensMeta["data_gen"] = tvgInfo.dataGenDict

    result = build_fcn(
        self,
        op,
        tvgInfo.tensorList,
        argsDict,
        validator_fcns=error_if_validators,
        error_name=error_name,
        qinfo=qinfo,
    )

    if not result:
        # The test is not valid
        logger.error(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid: attach compliance meta data (if any), then serialize
    if isinstance(result, TosaTestGen.BuildInfo):
        compliance = result.getComplianceInfo()
        if compliance:
            tensMeta["compliance"] = compliance
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003074
def createDynamicOpLists(self):
    """Expand the convolution ``*_TEMPLATE`` entries of TOSA_OP_LIST.

    For each kernel size, a shallow copy of the matching template is added
    under a name such as ``conv2d_3x3``, with its "filter", "template" and
    "real_name" fields set. Afterwards, every entry whose "template" field is
    truthy is removed. A second call is a no-op because ``conv2d_TEMPLATE``
    no longer exists.
    """
    opList = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in opList:
        # Already created these lists (can occur when class is initialized more than once)
        return

    if self.args.level8k:
        bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels2d = [[1, bigK], [bigK, 2]]
        kernels3d = [[1, bigK, 1], [2, 2, bigK]]
    else:
        kernels2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def addVariant(realName, kernel):
        # Copy "<realName>_TEMPLATE" into "<realName>_AxB[xC]" for this kernel.
        variant = opList[realName + "_TEMPLATE"].copy()
        variant["filter"] = kernel
        variant["template"] = False
        variant["real_name"] = realName
        opList[realName + "_" + "x".join(map(str, kernel))] = variant

    for kernel in kernels2d:
        for realName in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            addVariant(realName, kernel)
    for kernel in kernels3d:
        addVariant("conv3d", kernel)

    # Collect first, then delete: a dict must not change size while iterated.
    templateNames = [
        name for name, entry in opList.items() if entry.get("template")
    ]
    for name in templateNames:
        del opList[name]
3134
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.
    Look for missing required fields (datastructure linting).

    Every TOSA_OP_LIST entry must have a 2-element "operands" tuple, a
    4-element "build_fcn" tuple, a "types" field and an "op" field; otherwise
    an Exception naming the op is raised. Entries without "rank" receive
    DEFAULT_RANK_RANGE.
    """
    for opName, opInfo in self.TOSA_OP_LIST.items():
        # Unpacking validates both presence and arity of the tuples.
        try:
            _placeholders, _consts = opInfo["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(
                    opName
                )
            )

        try:
            _build, _tgen, _tvgen, _argen = opInfo["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    opName
                )
            )

        requiredFields = (
            ("types", "Op {} is missing a valid type list in TOSA_OP_LIST"),
            ("op", "Op {} is missing the Op field in TOSA_OP_LIST"),
        )
        for field, message in requiredFields:
            if field not in opInfo:
                raise Exception(message.format(opName))

        # Put in default rank range, if missing
        opInfo.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07003176
# Tensor operator list
# 'op': op name
# 'operands': tuple of (placeholder, const) operands
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#     if not specified, initOpListDefaults fills in DEFAULT_RANK_RANGE
#     (defined below as (1, TOSA_TENSOR_MAX_RANK))
# 'build_fcn': 4-tuple of (build function, TensorGen function,
#     TensorValuesGen function, ArgGen function); the ArgGen entry may be
#     falsy, in which case a single empty argument set is used
# 'types': array of datatypes to be tested
# Optional fields read by the generator include 'qgen',
# 'error_if_validators', 'invalid_test_validators' and 'data_gen';
# '*_TEMPLATE' entries carry 'template': True and are expanded by
# createDynamicOpLists.
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
# Floating-point, integer (excluding INT4) and boolean types
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
TYPE_FI16 = [DType.FP32, DType.INT16]  # FP32 and INT16 only

# INT8/INT16 plus the non-FP8 floating types (no INT32, no INT4)
TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
    [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
    [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
]

# Rank range applied by initOpListDefaults to ops without an explicit 'rank'
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003236
3237 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003238 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003239 "argmax": {
3240 "op": Op.ARGMAX,
3241 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003242 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003243 "build_fcn": (
3244 build_argmax,
3245 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003246 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003247 TosaArgGen.agAxis,
3248 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003249 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003250 "error_if_validators": (
3251 TosaErrorValidator.evAxisSmallerZero,
3252 TosaErrorValidator.evAxisLargerRank,
3253 TosaErrorValidator.evArgmaxOutputRankMismatch,
3254 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3255 TosaErrorValidator.evWrongRank,
3256 TosaErrorValidator.evWrongInputType,
3257 TosaErrorValidator.evWrongOutputType,
3258 TosaErrorValidator.evWrongInputList,
3259 TosaErrorValidator.evWrongOutputList,
3260 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003261 "data_gen": {
3262 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3263 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003264 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003265 "avg_pool2d": {
3266 "op": Op.AVG_POOL2D,
3267 "operands": (1, 0),
3268 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003269 "build_fcn": (
3270 build_pool2d,
3271 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003272 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003273 TosaArgGen.agPooling,
3274 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003275 "qgen": TosaQuantGen.qgUnary,
Won Jeon2c34b462024-02-06 18:37:00 +00003276 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003277 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003278 "error_if_validators": (
3279 TosaErrorValidator.evKernelSmallerOne,
3280 TosaErrorValidator.evStrideSmallerOne,
3281 TosaErrorValidator.evPadSmallerZero,
3282 TosaErrorValidator.evWrongRank,
3283 TosaErrorValidator.evWrongInputType,
3284 TosaErrorValidator.evWrongOutputType,
3285 TosaErrorValidator.evWrongInputList,
3286 TosaErrorValidator.evWrongOutputList,
3287 TosaErrorValidator.evInputZeroPointNotZero,
3288 TosaErrorValidator.evOutputZeroPointNotZero,
3289 TosaErrorValidator.evPadLargerEqualKernel,
3290 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003291 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00003292 TosaErrorValidator.evWrongAccumulatorType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003293 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003294 "data_gen": {
3295 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3296 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003297 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003298 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003299 "conv2d_TEMPLATE": {
3300 "op": Op.CONV2D,
3301 "operands": (1, 2),
3302 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003303 "build_fcn": (
3304 build_conv2d,
3305 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003306 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003307 TosaArgGen.agConv,
3308 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003309 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003310 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003311 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3312 "error_if_validators": (
3313 TosaErrorValidator.evWrongInputType,
3314 TosaErrorValidator.evWrongOutputType,
3315 TosaErrorValidator.evWrongInputList,
3316 TosaErrorValidator.evWrongOutputList,
3317 TosaErrorValidator.evInputZeroPointNotZero,
3318 TosaErrorValidator.evWeightZeroPointNotZero,
3319 TosaErrorValidator.evPadSmallerZero,
3320 TosaErrorValidator.evStrideSmallerOne,
3321 TosaErrorValidator.evDilationSmallerOne,
3322 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003323 TosaErrorValidator.evConvOutputShapeMismatch,
3324 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003325 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003326 "data_gen": {
3327 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3328 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003329 "template": True,
3330 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003331 # Templated operator. Filled in by createDynamicOpLists
3332 "conv3d_TEMPLATE": {
3333 "op": Op.CONV3D,
3334 "operands": (1, 2),
3335 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003336 "build_fcn": (
3337 build_conv3d,
3338 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003339 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003340 TosaArgGen.agConv,
3341 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003342 "qgen": TosaQuantGen.qgConv,
3343 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003344 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3345 "error_if_validators": (
3346 TosaErrorValidator.evWrongInputType,
3347 TosaErrorValidator.evWrongOutputType,
3348 TosaErrorValidator.evWrongInputList,
3349 TosaErrorValidator.evWrongOutputList,
3350 TosaErrorValidator.evInputZeroPointNotZero,
3351 TosaErrorValidator.evWeightZeroPointNotZero,
3352 TosaErrorValidator.evPadSmallerZero,
3353 TosaErrorValidator.evStrideSmallerOne,
3354 TosaErrorValidator.evDilationSmallerOne,
3355 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003356 TosaErrorValidator.evConvOutputShapeMismatch,
3357 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003358 ),
evacha0147ab1762024-01-29 13:23:23 +00003359 "data_gen": {
3360 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3361 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003362 "template": True,
3363 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003364 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003365 "depthwise_conv2d_TEMPLATE": {
3366 "op": Op.DEPTHWISE_CONV2D,
3367 "operands": (1, 2),
3368 "filter": [1, 1],
3369 "rank": (4, 4),
3370 "build_fcn": (
3371 build_depthwise_conv2d,
3372 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003373 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003374 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003375 ),
3376 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003377 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003378 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3379 "error_if_validators": (
3380 TosaErrorValidator.evWrongInputType,
3381 TosaErrorValidator.evWrongOutputType,
3382 TosaErrorValidator.evWrongInputList,
3383 TosaErrorValidator.evWrongOutputList,
3384 TosaErrorValidator.evInputZeroPointNotZero,
3385 TosaErrorValidator.evWeightZeroPointNotZero,
3386 TosaErrorValidator.evPadSmallerZero,
3387 TosaErrorValidator.evStrideSmallerOne,
3388 TosaErrorValidator.evDilationSmallerOne,
3389 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003390 TosaErrorValidator.evConvOutputShapeMismatch,
3391 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003392 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003393 "data_gen": {
3394 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3395 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003396 "template": True,
3397 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003398 "fully_connected": {
3399 "op": Op.FULLY_CONNECTED,
3400 "operands": (1, 2),
3401 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003402 "build_fcn": (
3403 build_fully_connected,
3404 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003405 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003406 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003407 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003408 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003409 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003410 "error_if_validators": (
3411 TosaErrorValidator.evInputZeroPointNotZero,
3412 TosaErrorValidator.evWeightZeroPointNotZero,
3413 TosaErrorValidator.evWrongRank,
3414 TosaErrorValidator.evWrongInputType,
3415 TosaErrorValidator.evWrongOutputType,
3416 TosaErrorValidator.evWrongInputList,
3417 TosaErrorValidator.evWrongOutputList,
3418 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003419 "data_gen": {
3420 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3421 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003422 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003423 "matmul": {
3424 "op": Op.MATMUL,
3425 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003426 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003427 "build_fcn": (
3428 build_matmul,
3429 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003430 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003431 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003432 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003433 "qgen": TosaQuantGen.qgMatmul,
Won Jeon2c34b462024-02-06 18:37:00 +00003434 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003435 "error_if_validators": (
3436 TosaErrorValidator.evInputZeroPointNotZero,
3437 TosaErrorValidator.evWrongRank,
3438 TosaErrorValidator.evWrongInputType,
3439 TosaErrorValidator.evWrongOutputType,
3440 TosaErrorValidator.evWrongInputList,
3441 TosaErrorValidator.evWrongOutputList,
3442 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003443 "data_gen": {
3444 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003445 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003446 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003447 "max_pool2d": {
3448 "op": Op.MAX_POOL2D,
3449 "operands": (1, 0),
3450 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003451 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003452 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003453 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003454 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003455 TosaArgGen.agPooling,
3456 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003457 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003458 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003459 "error_if_validators": (
3460 TosaErrorValidator.evKernelSmallerOne,
3461 TosaErrorValidator.evStrideSmallerOne,
3462 TosaErrorValidator.evPadSmallerZero,
3463 TosaErrorValidator.evWrongRank,
3464 TosaErrorValidator.evWrongInputType,
3465 TosaErrorValidator.evWrongOutputType,
3466 TosaErrorValidator.evWrongInputList,
3467 TosaErrorValidator.evWrongOutputList,
3468 TosaErrorValidator.evPadLargerEqualKernel,
3469 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003470 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003471 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003472 "data_gen": {
3473 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3474 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003475 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003476 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003477 "transpose_conv2d_TEMPLATE": {
3478 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003479 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003480 "rank": (4, 4),
3481 "build_fcn": (
3482 build_transpose_conv2d,
3483 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003484 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003485 TosaArgGen.agTransposeConv2D,
3486 ),
3487 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003488 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003489 "invalid_test_validators": (
3490 TosaInvalidValidator.ivHeightWidthInvalid,
3491 TosaInvalidValidator.ivNonPositiveOutputShape,
3492 ),
3493 "error_if_validators": (
3494 TosaErrorValidator.evWrongInputType,
3495 TosaErrorValidator.evWrongOutputType,
3496 TosaErrorValidator.evWrongInputList,
3497 TosaErrorValidator.evWrongOutputList,
3498 TosaErrorValidator.evInputZeroPointNotZero,
3499 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003500 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003501 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003502 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003503 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003504 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00003505 "data_gen": {
3506 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3507 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003508 "template": True,
3509 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003510 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003511 "clamp": {
3512 "op": Op.CLAMP,
3513 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003514 "build_fcn": (
3515 build_clamp,
3516 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003517 TosaTensorValuesGen.tvgLazyGenDefault,
3518 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003519 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003520 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003521 "error_if_validators": (
3522 TosaErrorValidator.evMaxSmallerMin,
3523 TosaErrorValidator.evWrongInputType,
3524 TosaErrorValidator.evWrongOutputType,
3525 TosaErrorValidator.evWrongInputList,
3526 TosaErrorValidator.evWrongOutputList,
3527 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003528 "data_gen": {
3529 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3530 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003531 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003532 "sigmoid": {
3533 "op": Op.SIGMOID,
3534 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003535 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003536 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003537 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003538 TosaTensorValuesGen.tvgLazyGenDefault,
3539 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003540 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003541 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003542 "error_if_validators": (
3543 TosaErrorValidator.evWrongInputType,
3544 TosaErrorValidator.evWrongOutputType,
3545 TosaErrorValidator.evWrongInputList,
3546 TosaErrorValidator.evWrongOutputList,
3547 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003548 "data_gen": {
3549 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3550 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003551 },
3552 "tanh": {
3553 "op": Op.TANH,
3554 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003555 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003556 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003557 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003558 TosaTensorValuesGen.tvgLazyGenDefault,
3559 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003560 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003561 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003562 "error_if_validators": (
3563 TosaErrorValidator.evWrongInputType,
3564 TosaErrorValidator.evWrongOutputType,
3565 TosaErrorValidator.evWrongInputList,
3566 TosaErrorValidator.evWrongOutputList,
3567 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003568 "data_gen": {
3569 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3570 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003571 "compliance": {
3572 "abs_error_lower_bound": 0.5,
3573 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003574 },
Won Jeon78155c62023-06-10 00:20:04 +00003575 "erf": {
3576 "op": Op.ERF,
3577 "operands": (1, 0),
3578 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003579 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003580 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003581 TosaTensorValuesGen.tvgLazyGenDefault,
3582 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003583 ),
3584 "types": TYPE_FP,
3585 "error_if_validators": (
3586 TosaErrorValidator.evWrongInputType,
3587 TosaErrorValidator.evWrongOutputType,
3588 TosaErrorValidator.evWrongInputList,
3589 TosaErrorValidator.evWrongOutputList,
3590 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003591 "data_gen": {
3592 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3593 },
3594 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003595 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003596 # Elementwise Binary Operators
3597 "add": {
3598 "op": Op.ADD,
3599 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003600 "build_fcn": (
3601 build_binary_broadcast,
3602 TosaTensorGen.tgBroadcastFuzz,
3603 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003604 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003605 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003606 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003607 "error_if_validators": (
3608 TosaErrorValidator.evRankMismatch,
3609 TosaErrorValidator.evWrongInputType,
3610 TosaErrorValidator.evWrongOutputType,
3611 TosaErrorValidator.evWrongInputList,
3612 TosaErrorValidator.evWrongOutputList,
3613 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003614 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003615 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003616 "data_gen": {
3617 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3618 },
3619 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003620 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003621 "arithmetic_right_shift": {
3622 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3623 "operands": (2, 0),
3624 "build_fcn": (
3625 build_arithmetic_right_shift,
3626 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003627 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003628 TosaArgGen.agArithmeticRightShift,
3629 ),
3630 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003631 "error_if_validators": (
3632 TosaErrorValidator.evRankMismatch,
3633 TosaErrorValidator.evWrongInputType,
3634 TosaErrorValidator.evWrongOutputType,
3635 TosaErrorValidator.evWrongInputList,
3636 TosaErrorValidator.evWrongOutputList,
3637 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003638 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003639 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003640 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003641 "bitwise_and": {
3642 "op": Op.BITWISE_AND,
3643 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003644 "build_fcn": (
3645 build_binary_broadcast,
3646 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003647 TosaTensorValuesGen.tvgLazyGenDefault,
3648 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003649 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003650 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003651 "error_if_validators": (
3652 TosaErrorValidator.evRankMismatch,
3653 TosaErrorValidator.evWrongInputType,
3654 TosaErrorValidator.evWrongOutputType,
3655 TosaErrorValidator.evWrongInputList,
3656 TosaErrorValidator.evWrongOutputList,
3657 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003658 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003659 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003660 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003661 "bitwise_or": {
3662 "op": Op.BITWISE_OR,
3663 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003664 "build_fcn": (
3665 build_binary_broadcast,
3666 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003667 TosaTensorValuesGen.tvgLazyGenDefault,
3668 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003669 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003670 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003671 "error_if_validators": (
3672 TosaErrorValidator.evRankMismatch,
3673 TosaErrorValidator.evWrongInputType,
3674 TosaErrorValidator.evWrongOutputType,
3675 TosaErrorValidator.evWrongInputList,
3676 TosaErrorValidator.evWrongOutputList,
3677 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003678 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003679 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003680 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003681 "bitwise_xor": {
3682 "op": Op.BITWISE_XOR,
3683 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003684 "build_fcn": (
3685 build_binary_broadcast,
3686 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003687 TosaTensorValuesGen.tvgLazyGenDefault,
3688 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003689 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003690 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003691 "error_if_validators": (
3692 TosaErrorValidator.evRankMismatch,
3693 TosaErrorValidator.evWrongInputType,
3694 TosaErrorValidator.evWrongOutputType,
3695 TosaErrorValidator.evWrongInputList,
3696 TosaErrorValidator.evWrongOutputList,
3697 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003698 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003699 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003700 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003701 "intdiv": {
3702 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003703 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003704 "build_fcn": (
3705 build_binary_broadcast,
3706 TosaTensorGen.tgBroadcastFuzz,
3707 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003708 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003709 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003710 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003711 "error_if_validators": (
3712 TosaErrorValidator.evRankMismatch,
3713 TosaErrorValidator.evWrongInputType,
3714 TosaErrorValidator.evWrongOutputType,
3715 TosaErrorValidator.evWrongInputList,
3716 TosaErrorValidator.evWrongOutputList,
3717 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003718 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003719 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003720 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003721 "logical_and": {
3722 "op": Op.LOGICAL_AND,
3723 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003724 "build_fcn": (
3725 build_binary_broadcast,
3726 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003727 TosaTensorValuesGen.tvgLazyGenDefault,
3728 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003729 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003730 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003731 "error_if_validators": (
3732 TosaErrorValidator.evRankMismatch,
3733 TosaErrorValidator.evWrongInputType,
3734 TosaErrorValidator.evWrongOutputType,
3735 TosaErrorValidator.evWrongInputList,
3736 TosaErrorValidator.evWrongOutputList,
3737 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003738 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003739 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003740 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003741 "logical_left_shift": {
3742 "op": Op.LOGICAL_LEFT_SHIFT,
3743 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003744 "build_fcn": (
3745 build_binary_broadcast,
3746 TosaTensorGen.tgBroadcastFuzz,
3747 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003748 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003749 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003750 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003751 "error_if_validators": (
3752 TosaErrorValidator.evRankMismatch,
3753 TosaErrorValidator.evWrongInputType,
3754 TosaErrorValidator.evWrongOutputType,
3755 TosaErrorValidator.evWrongInputList,
3756 TosaErrorValidator.evWrongOutputList,
3757 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003758 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003759 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003760 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003761 "logical_right_shift": {
3762 "op": Op.LOGICAL_RIGHT_SHIFT,
3763 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003764 "build_fcn": (
3765 build_binary_broadcast,
3766 TosaTensorGen.tgBroadcastFuzz,
3767 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003768 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003769 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003770 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003771 "error_if_validators": (
3772 TosaErrorValidator.evRankMismatch,
3773 TosaErrorValidator.evWrongInputType,
3774 TosaErrorValidator.evWrongOutputType,
3775 TosaErrorValidator.evWrongInputList,
3776 TosaErrorValidator.evWrongOutputList,
3777 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003778 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003779 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003780 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003781 "logical_or": {
3782 "op": Op.LOGICAL_OR,
3783 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003784 "build_fcn": (
3785 build_binary_broadcast,
3786 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003787 TosaTensorValuesGen.tvgLazyGenDefault,
3788 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003789 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003790 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003791 "error_if_validators": (
3792 TosaErrorValidator.evRankMismatch,
3793 TosaErrorValidator.evWrongInputType,
3794 TosaErrorValidator.evWrongOutputType,
3795 TosaErrorValidator.evWrongInputList,
3796 TosaErrorValidator.evWrongOutputList,
3797 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003798 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003799 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003800 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003801 "logical_xor": {
3802 "op": Op.LOGICAL_XOR,
3803 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003804 "build_fcn": (
3805 build_binary_broadcast,
3806 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003807 TosaTensorValuesGen.tvgLazyGenDefault,
3808 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003809 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003810 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003811 "error_if_validators": (
3812 TosaErrorValidator.evRankMismatch,
3813 TosaErrorValidator.evWrongInputType,
3814 TosaErrorValidator.evWrongOutputType,
3815 TosaErrorValidator.evWrongInputList,
3816 TosaErrorValidator.evWrongOutputList,
3817 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003818 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003819 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003820 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003821 "maximum": {
3822 "op": Op.MAXIMUM,
3823 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003824 "build_fcn": (
3825 build_binary_broadcast,
3826 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003827 TosaTensorValuesGen.tvgLazyGenDefault,
3828 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003829 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003830 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003831 "error_if_validators": (
3832 TosaErrorValidator.evRankMismatch,
3833 TosaErrorValidator.evWrongInputType,
3834 TosaErrorValidator.evWrongOutputType,
3835 TosaErrorValidator.evWrongInputList,
3836 TosaErrorValidator.evWrongOutputList,
3837 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003838 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003839 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003840 "data_gen": {
3841 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3842 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003843 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003844 "minimum": {
3845 "op": Op.MINIMUM,
3846 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003847 "build_fcn": (
3848 build_binary_broadcast,
3849 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003850 TosaTensorValuesGen.tvgLazyGenDefault,
3851 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003852 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003853 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003854 "error_if_validators": (
3855 TosaErrorValidator.evRankMismatch,
3856 TosaErrorValidator.evWrongInputType,
3857 TosaErrorValidator.evWrongOutputType,
3858 TosaErrorValidator.evWrongInputList,
3859 TosaErrorValidator.evWrongOutputList,
3860 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003861 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003862 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003863 "data_gen": {
3864 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3865 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003866 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003867 "mul": {
3868 "op": Op.MUL,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003869 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003870 "build_fcn": (
3871 build_mul,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003872 TosaTensorGen.tgMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003873 TosaTensorValuesGen.tvgMul,
3874 TosaArgGen.agMul,
3875 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003876 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003877 "error_if_validators": (
3878 TosaErrorValidator.evWrongInputType,
3879 TosaErrorValidator.evWrongOutputType,
3880 TosaErrorValidator.evWrongInputList,
3881 TosaErrorValidator.evWrongOutputList,
3882 TosaErrorValidator.evRankMismatch,
3883 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003884 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003885 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003886 "data_gen": {
3887 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3888 },
3889 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003890 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003891 "pow": {
3892 "op": Op.POW,
3893 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003894 "build_fcn": (
3895 build_binary_broadcast,
3896 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003897 TosaTensorValuesGen.tvgPow,
3898 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003899 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003900 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003901 "error_if_validators": (
3902 TosaErrorValidator.evRankMismatch,
3903 TosaErrorValidator.evWrongInputType,
3904 TosaErrorValidator.evWrongOutputType,
3905 TosaErrorValidator.evWrongInputList,
3906 TosaErrorValidator.evWrongOutputList,
3907 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003908 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003909 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003910 "data_gen": {
3911 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3912 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003913 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003914 "sub": {
3915 "op": Op.SUB,
3916 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003917 "build_fcn": (
3918 build_binary_broadcast,
3919 TosaTensorGen.tgBroadcastFuzz,
3920 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003921 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003922 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003923 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003924 "error_if_validators": (
3925 TosaErrorValidator.evRankMismatch,
3926 TosaErrorValidator.evWrongInputType,
3927 TosaErrorValidator.evWrongOutputType,
3928 TosaErrorValidator.evWrongInputList,
3929 TosaErrorValidator.evWrongOutputList,
3930 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003931 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003932 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003933 "data_gen": {
3934 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3935 },
3936 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003937 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003938 "table": {
3939 "op": Op.TABLE,
3940 # Use the automatic generation functions to create the input array
3941 # but create the table tensor in the build function, as it may be
3942 # a different type from the input
3943 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003944 "build_fcn": (
3945 build_table,
3946 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003947 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003948 TosaArgGen.agTable,
3949 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003950 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003951 "error_if_validators": (
3952 TosaErrorValidator.evWrongInputType,
3953 TosaErrorValidator.evWrongOutputType,
3954 TosaErrorValidator.evWrongInputList,
3955 TosaErrorValidator.evWrongOutputList,
3956 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003957 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003958 # Elementwise Unary operators
3959 "abs": {
3960 "op": Op.ABS,
3961 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003962 "build_fcn": (
3963 build_unary,
3964 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003965 TosaTensorValuesGen.tvgLazyGenDefault,
3966 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003967 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003968 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003969 "error_if_validators": (
3970 TosaErrorValidator.evWrongInputType,
3971 TosaErrorValidator.evWrongOutputType,
3972 TosaErrorValidator.evWrongInputList,
3973 TosaErrorValidator.evWrongOutputList,
3974 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003975 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00003976 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003977 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003978 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003979 "bitwise_not": {
3980 "op": Op.BITWISE_NOT,
3981 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003982 "build_fcn": (
3983 build_unary,
3984 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003985 TosaTensorValuesGen.tvgLazyGenDefault,
3986 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003987 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003988 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003989 "error_if_validators": (
3990 TosaErrorValidator.evWrongInputType,
3991 TosaErrorValidator.evWrongOutputType,
3992 TosaErrorValidator.evWrongInputList,
3993 TosaErrorValidator.evWrongOutputList,
3994 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003995 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003996 "ceil": {
3997 "op": Op.CEIL,
3998 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003999 "build_fcn": (
4000 build_unary,
4001 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004002 TosaTensorValuesGen.tvgLazyGenDefault,
4003 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004004 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004005 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004006 "error_if_validators": (
4007 TosaErrorValidator.evWrongInputType,
4008 TosaErrorValidator.evWrongOutputType,
4009 TosaErrorValidator.evWrongInputList,
4010 TosaErrorValidator.evWrongOutputList,
4011 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004012 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004013 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004014 },
4015 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004016 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004017 "clz": {
4018 "op": Op.CLZ,
4019 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004020 "build_fcn": (
4021 build_unary,
4022 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004023 TosaTensorValuesGen.tvgLazyGenDefault,
4024 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004025 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004026 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004027 "error_if_validators": (
4028 TosaErrorValidator.evWrongInputType,
4029 TosaErrorValidator.evWrongOutputType,
4030 TosaErrorValidator.evWrongInputList,
4031 TosaErrorValidator.evWrongOutputList,
4032 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004033 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004034 "cos": {
4035 "op": Op.COS,
4036 "operands": (1, 0),
4037 "build_fcn": (
4038 build_unary,
4039 TosaTensorGen.tgBasic,
4040 TosaTensorValuesGen.tvgLazyGenDefault,
4041 TosaArgGen.agNone,
4042 ),
4043 "types": TYPE_FP,
4044 "error_if_validators": (
4045 TosaErrorValidator.evWrongInputType,
4046 TosaErrorValidator.evWrongOutputType,
4047 TosaErrorValidator.evWrongInputList,
4048 TosaErrorValidator.evWrongOutputList,
4049 ),
4050 "data_gen": {
4051 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4052 },
4053 "compliance": {"abs_error_normal_divisor": 2},
4054 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004055 "exp": {
4056 "op": Op.EXP,
4057 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004058 "build_fcn": (
4059 build_unary,
4060 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004061 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004062 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004063 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004064 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004065 "error_if_validators": (
4066 TosaErrorValidator.evWrongInputType,
4067 TosaErrorValidator.evWrongOutputType,
4068 TosaErrorValidator.evWrongInputList,
4069 TosaErrorValidator.evWrongOutputList,
4070 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004071 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004072 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004073 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004074 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004075 "floor": {
4076 "op": Op.FLOOR,
4077 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004078 "build_fcn": (
4079 build_unary,
4080 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004081 TosaTensorValuesGen.tvgLazyGenDefault,
4082 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004083 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004084 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004085 "error_if_validators": (
4086 TosaErrorValidator.evWrongInputType,
4087 TosaErrorValidator.evWrongOutputType,
4088 TosaErrorValidator.evWrongInputList,
4089 TosaErrorValidator.evWrongOutputList,
4090 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004091 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004092 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004093 },
4094 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004095 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004096 "log": {
4097 "op": Op.LOG,
4098 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004099 "build_fcn": (
4100 build_unary,
4101 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004102 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004103 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004104 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004105 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004106 "error_if_validators": (
4107 TosaErrorValidator.evWrongInputType,
4108 TosaErrorValidator.evWrongOutputType,
4109 TosaErrorValidator.evWrongInputList,
4110 TosaErrorValidator.evWrongOutputList,
4111 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004112 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004113 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004114 },
4115 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004116 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004117 "logical_not": {
4118 "op": Op.LOGICAL_NOT,
4119 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004120 "build_fcn": (
4121 build_unary,
4122 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004123 TosaTensorValuesGen.tvgLazyGenDefault,
4124 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004125 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004126 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004127 "error_if_validators": (
4128 TosaErrorValidator.evWrongInputType,
4129 TosaErrorValidator.evWrongOutputType,
4130 TosaErrorValidator.evWrongInputList,
4131 TosaErrorValidator.evWrongOutputList,
4132 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004133 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004134 "negate": {
4135 "op": Op.NEGATE,
4136 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004137 "build_fcn": (
4138 build_unary,
4139 TosaTensorGen.tgBasic,
4140 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004141 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004142 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004143 "qgen": TosaQuantGen.qgUnary,
4144 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004145 "error_if_validators": (
4146 TosaErrorValidator.evInputZeroPointNotZero,
4147 TosaErrorValidator.evOutputZeroPointNotZero,
4148 TosaErrorValidator.evWrongInputType,
4149 TosaErrorValidator.evWrongOutputType,
4150 TosaErrorValidator.evWrongInputList,
4151 TosaErrorValidator.evWrongOutputList,
4152 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004153 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004154 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004155 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004156 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004157 "reciprocal": {
4158 "op": Op.RECIPROCAL,
4159 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004160 "build_fcn": (
4161 build_unary,
4162 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004163 TosaTensorValuesGen.tvgLazyGenDefault,
4164 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004165 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004166 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004167 "error_if_validators": (
4168 TosaErrorValidator.evWrongInputType,
4169 TosaErrorValidator.evWrongOutputType,
4170 TosaErrorValidator.evWrongInputList,
4171 TosaErrorValidator.evWrongOutputList,
4172 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004173 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004174 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004175 },
4176 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004177 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004178 "rsqrt": {
4179 "op": Op.RSQRT,
4180 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004181 "build_fcn": (
4182 build_unary,
4183 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004184 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004185 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004186 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004187 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004188 "error_if_validators": (
4189 TosaErrorValidator.evWrongInputType,
4190 TosaErrorValidator.evWrongOutputType,
4191 TosaErrorValidator.evWrongInputList,
4192 TosaErrorValidator.evWrongOutputList,
4193 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004194 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004195 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004196 },
4197 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004198 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004199 "sin": {
4200 "op": Op.SIN,
4201 "operands": (1, 0),
4202 "build_fcn": (
4203 build_unary,
4204 TosaTensorGen.tgBasic,
4205 TosaTensorValuesGen.tvgLazyGenDefault,
4206 TosaArgGen.agNone,
4207 ),
4208 "types": TYPE_FP,
4209 "error_if_validators": (
4210 TosaErrorValidator.evWrongInputType,
4211 TosaErrorValidator.evWrongOutputType,
4212 TosaErrorValidator.evWrongInputList,
4213 TosaErrorValidator.evWrongOutputList,
4214 ),
4215 "data_gen": {
4216 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4217 },
4218 "compliance": {"abs_error_normal_divisor": 2},
4219 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004220 # Elementwise Ternary operators
4221 "select": {
4222 "op": Op.SELECT,
4223 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004224 "build_fcn": (
4225 build_select,
4226 TosaTensorGen.tgBroadcastFuzz,
4227 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004228 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004229 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004230 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004231 "error_if_validators": (
4232 TosaErrorValidator.evRankMismatch,
4233 TosaErrorValidator.evWrongInputType,
4234 TosaErrorValidator.evWrongOutputType,
4235 TosaErrorValidator.evWrongInputList,
4236 TosaErrorValidator.evWrongOutputList,
4237 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004238 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004239 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004240 "data_gen": {
4241 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4242 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004243 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004244 # Comparison operators
4245 "equal": {
4246 "op": Op.EQUAL,
4247 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004248 "build_fcn": (
4249 build_comparison,
4250 TosaTensorGen.tgBroadcastFuzz,
4251 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004252 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004253 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004254 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004255 "error_if_validators": (
4256 TosaErrorValidator.evRankMismatch,
4257 TosaErrorValidator.evWrongInputType,
4258 TosaErrorValidator.evWrongOutputType,
4259 TosaErrorValidator.evWrongInputList,
4260 TosaErrorValidator.evWrongOutputList,
4261 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004262 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004263 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004264 "data_gen": {
4265 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4266 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004267 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004268 "greater_equal": {
4269 "op": Op.GREATER_EQUAL,
4270 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004271 "build_fcn": (
4272 build_comparison,
4273 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004274 TosaTensorValuesGen.tvgLazyGenDefault,
4275 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004276 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004277 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004278 "error_if_validators": (
4279 TosaErrorValidator.evRankMismatch,
4280 TosaErrorValidator.evWrongInputType,
4281 TosaErrorValidator.evWrongOutputType,
4282 TosaErrorValidator.evWrongInputList,
4283 TosaErrorValidator.evWrongOutputList,
4284 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004285 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004286 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004287 "data_gen": {
4288 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4289 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004290 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004291 "greater": {
4292 "op": Op.GREATER,
4293 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004294 "build_fcn": (
4295 build_comparison,
4296 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004297 TosaTensorValuesGen.tvgLazyGenDefault,
4298 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004299 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004300 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004301 "error_if_validators": (
4302 TosaErrorValidator.evRankMismatch,
4303 TosaErrorValidator.evWrongInputType,
4304 TosaErrorValidator.evWrongOutputType,
4305 TosaErrorValidator.evWrongInputList,
4306 TosaErrorValidator.evWrongOutputList,
4307 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004308 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004309 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004310 "data_gen": {
4311 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4312 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004313 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004314 # Reduction operators
4315 "reduce_all": {
4316 "op": Op.REDUCE_ALL,
4317 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004318 "build_fcn": (
4319 build_reduce,
4320 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004321 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004322 TosaArgGen.agAxis,
4323 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004324 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004325 "error_if_validators": (
4326 TosaErrorValidator.evAxisLargerRank,
4327 TosaErrorValidator.evAxisSmallerZero,
4328 TosaErrorValidator.evShapeOfAxisNotOne,
4329 TosaErrorValidator.evWrongInputType,
4330 TosaErrorValidator.evWrongOutputType,
4331 TosaErrorValidator.evWrongRank,
4332 TosaErrorValidator.evWrongInputList,
4333 TosaErrorValidator.evWrongOutputList,
4334 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004335 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004336 "reduce_any": {
4337 "op": Op.REDUCE_ANY,
4338 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004339 "build_fcn": (
4340 build_reduce,
4341 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004342 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004343 TosaArgGen.agAxis,
4344 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004345 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004346 "error_if_validators": (
4347 TosaErrorValidator.evAxisLargerRank,
4348 TosaErrorValidator.evAxisSmallerZero,
4349 TosaErrorValidator.evShapeOfAxisNotOne,
4350 TosaErrorValidator.evWrongInputType,
4351 TosaErrorValidator.evWrongOutputType,
4352 TosaErrorValidator.evWrongRank,
4353 TosaErrorValidator.evWrongInputList,
4354 TosaErrorValidator.evWrongOutputList,
4355 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004356 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004357 "reduce_max": {
4358 "op": Op.REDUCE_MAX,
4359 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004360 "build_fcn": (
4361 build_reduce,
4362 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004363 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004364 TosaArgGen.agAxis,
4365 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004366 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004367 "error_if_validators": (
4368 TosaErrorValidator.evAxisLargerRank,
4369 TosaErrorValidator.evAxisSmallerZero,
4370 TosaErrorValidator.evShapeOfAxisNotOne,
4371 TosaErrorValidator.evWrongInputType,
4372 TosaErrorValidator.evWrongOutputType,
4373 TosaErrorValidator.evWrongRank,
4374 TosaErrorValidator.evWrongInputList,
4375 TosaErrorValidator.evWrongOutputList,
4376 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004377 "data_gen": {
4378 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4379 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004380 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004381 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004382 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004383 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004384 "build_fcn": (
4385 build_reduce,
4386 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004387 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004388 TosaArgGen.agAxis,
4389 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004390 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004391 "error_if_validators": (
4392 TosaErrorValidator.evAxisLargerRank,
4393 TosaErrorValidator.evAxisSmallerZero,
4394 TosaErrorValidator.evShapeOfAxisNotOne,
4395 TosaErrorValidator.evWrongInputType,
4396 TosaErrorValidator.evWrongOutputType,
4397 TosaErrorValidator.evWrongRank,
4398 TosaErrorValidator.evWrongInputList,
4399 TosaErrorValidator.evWrongOutputList,
4400 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004401 "data_gen": {
4402 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4403 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004404 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004405 "reduce_product": {
4406 "op": Op.REDUCE_PRODUCT,
4407 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004408 "build_fcn": (
4409 build_reduce,
4410 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004411 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004412 TosaArgGen.agAxis,
4413 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004414 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004415 "error_if_validators": (
4416 TosaErrorValidator.evAxisLargerRank,
4417 TosaErrorValidator.evAxisSmallerZero,
4418 TosaErrorValidator.evShapeOfAxisNotOne,
4419 TosaErrorValidator.evWrongInputType,
4420 TosaErrorValidator.evWrongOutputType,
4421 TosaErrorValidator.evWrongRank,
4422 TosaErrorValidator.evWrongInputList,
4423 TosaErrorValidator.evWrongOutputList,
4424 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004425 "data_gen": {
4426 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4427 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004428 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004429 "reduce_sum": {
4430 "op": Op.REDUCE_SUM,
4431 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004432 "build_fcn": (
4433 build_reduce,
4434 TosaTensorGen.tgBasic,
4435 TosaTensorValuesGen.tvgReduceSum,
4436 TosaArgGen.agAxis,
4437 ),
James Ward24dbc422022-10-19 12:20:31 +01004438 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004439 "error_if_validators": (
4440 TosaErrorValidator.evAxisLargerRank,
4441 TosaErrorValidator.evAxisSmallerZero,
4442 TosaErrorValidator.evShapeOfAxisNotOne,
4443 TosaErrorValidator.evWrongInputType,
4444 TosaErrorValidator.evWrongOutputType,
4445 TosaErrorValidator.evWrongRank,
4446 TosaErrorValidator.evWrongInputList,
4447 TosaErrorValidator.evWrongOutputList,
4448 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004449 "data_gen": {
4450 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4451 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004452 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004453 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004454 "concat": {
4455 "op": Op.CONCAT,
4456 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004457 "build_fcn": (
4458 build_concat,
4459 TosaTensorGen.tgConcat,
4460 TosaTensorValuesGen.tvgConcat,
4461 TosaArgGen.agAxis,
4462 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004463 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004464 "error_if_validators": (
4465 TosaErrorValidator.evAxisLargerRank,
4466 TosaErrorValidator.evAxisSmallerZero,
4467 TosaErrorValidator.evConcatInputRankMismatch,
4468 TosaErrorValidator.evConcatShapeSumMismatch,
4469 TosaErrorValidator.evConcatInputDimMismatch,
4470 TosaErrorValidator.evWrongInputType,
4471 TosaErrorValidator.evWrongOutputType,
4472 TosaErrorValidator.evWrongOutputList,
4473 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004474 "data_gen": {
4475 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4476 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004477 },
4478 "pad": {
4479 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004480 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004481 "build_fcn": (
4482 build_pad,
4483 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004484 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004485 TosaArgGen.agPad,
4486 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004487 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004488 "error_if_validators": (
4489 TosaErrorValidator.evWrongInputType,
4490 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004491 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004492 TosaErrorValidator.evWrongOutputType,
4493 TosaErrorValidator.evWrongInputList,
4494 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004495 TosaErrorValidator.evRankMismatch,
4496 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004497 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004498 "data_gen": {
4499 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4500 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004501 },
Won Jeona21b2e82023-08-10 10:33:01 +00004502 "dim": {
4503 "op": Op.DIM,
4504 "operands": (1, 0),
4505 "build_fcn": (
4506 build_dim,
4507 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004508 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004509 TosaArgGen.agAxis,
4510 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004511 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Won Jeona21b2e82023-08-10 10:33:01 +00004512 "error_if_validators": (
4513 TosaErrorValidator.evAxisLargerRank,
4514 TosaErrorValidator.evAxisSmallerZero,
4515 TosaErrorValidator.evWrongInputType,
4516 TosaErrorValidator.evWrongInputList,
4517 TosaErrorValidator.evWrongOutputList,
4518 TosaErrorValidator.evWrongRank,
4519 ),
4520 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004521 "reshape": {
4522 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004523 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004524 "build_fcn": (
4525 build_reshape,
4526 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004527 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004528 TosaArgGen.agReshape,
4529 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004530 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004531 "error_if_validators": (
4532 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4533 TosaErrorValidator.evWrongInputType,
4534 TosaErrorValidator.evWrongOutputType,
4535 TosaErrorValidator.evWrongInputList,
4536 TosaErrorValidator.evWrongOutputList,
4537 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004538 "data_gen": {
4539 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4540 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004541 },
4542 "reverse": {
4543 "op": Op.REVERSE,
4544 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004545 "build_fcn": (
4546 build_reverse,
4547 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004548 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004549 TosaArgGen.agAxis,
4550 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004551 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004552 "error_if_validators": (
4553 TosaErrorValidator.evAxisSmallerZero,
4554 TosaErrorValidator.evAxisLargerRank,
4555 TosaErrorValidator.evWrongInputType,
4556 TosaErrorValidator.evWrongOutputType,
4557 TosaErrorValidator.evWrongInputList,
4558 TosaErrorValidator.evWrongOutputList,
4559 ),
evacha0198477222024-01-26 12:25:32 +00004560 "data_gen": {
4561 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4562 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004563 },
4564 "slice": {
4565 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004566 "operands": (3, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004567 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004568 "build_fcn": (
4569 build_slice,
4570 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004571 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004572 TosaArgGen.agSlice,
4573 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004574 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004575 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004576 # TODO Turn off these error categories for now as the reference
4577 # model cannot allocate memory space for empty tensor. We probably
4578 # can report an accurate error messege at the right place during
4579 # exeuction.
4580 # TosaErrorValidator.evStartSmallerZero,
4581 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004582 TosaErrorValidator.evStartSizeOutsideBounds,
4583 TosaErrorValidator.evSizeOutputShapeMismatch,
4584 TosaErrorValidator.evInputSizeStartLengthMismatch,
4585 TosaErrorValidator.evWrongRank,
4586 TosaErrorValidator.evWrongInputType,
4587 TosaErrorValidator.evWrongOutputType,
4588 TosaErrorValidator.evWrongInputList,
4589 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004590 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004591 ),
evacha017f7d4252024-01-24 12:08:09 +00004592 "data_gen": {
4593 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4594 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004595 },
4596 "tile": {
4597 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004598 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004599 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004600 "build_fcn": (
4601 build_tile,
4602 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004603 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004604 TosaArgGen.agTile,
4605 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004606 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004607 "error_if_validators": (
4608 TosaErrorValidator.evWrongInputType,
4609 TosaErrorValidator.evWrongOutputType,
4610 TosaErrorValidator.evWrongInputList,
4611 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004612 TosaErrorValidator.evRankMismatch,
4613 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004614 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004615 "data_gen": {
4616 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4617 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004618 },
4619 "transpose": {
4620 "op": Op.TRANSPOSE,
4621 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004622 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004623 "build_fcn": (
4624 build_transpose,
4625 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004626 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004627 TosaArgGen.agTranspose,
4628 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004629 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004630 "error_if_validators": (
4631 TosaErrorValidator.evIndexOutsideBounds,
4632 TosaErrorValidator.evIndexUsedTwice,
4633 TosaErrorValidator.evWrongInputType,
4634 TosaErrorValidator.evWrongOutputType,
4635 TosaErrorValidator.evWrongInputList,
4636 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004637 TosaErrorValidator.evWrongRank,
4638 TosaErrorValidator.evRankMismatch,
4639 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004640 ),
evacha0198477222024-01-26 12:25:32 +00004641 "data_gen": {
4642 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4643 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004644 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004645 # Data nodes
4646 "const": {
4647 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004648 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004649 "build_fcn": (
4650 build_const,
4651 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004652 TosaTensorValuesGen.tvgLazyGenDefault,
4653 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004654 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004655 "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
evacha0198477222024-01-26 12:25:32 +00004656 "data_gen": {
4657 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4658 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004659 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004660 "identity": {
4661 "op": Op.IDENTITY,
4662 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004663 "build_fcn": (
4664 build_unary,
4665 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004666 TosaTensorValuesGen.tvgLazyGenDefault,
4667 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004668 ),
evacha011adff832024-03-06 17:33:44 +00004669 "types": TYPE_FIB + [DType.INT4, DType.INT48],
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004670 "data_gen": {
4671 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4672 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004673 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004674 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004675 "gather": {
4676 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004677 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004678 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004679 "build_fcn": (
4680 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004681 TosaTensorGen.tgGather,
4682 TosaTensorValuesGen.tvgGather,
4683 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004684 ),
James Ward24dbc422022-10-19 12:20:31 +01004685 "types": (
4686 DType.INT8,
4687 DType.INT16,
4688 DType.INT32,
4689 DType.FP16,
4690 DType.BF16,
4691 DType.FP32,
Won Jeon2c34b462024-02-06 18:37:00 +00004692 DType.FP8E4M3,
4693 DType.FP8E5M2,
James Ward24dbc422022-10-19 12:20:31 +01004694 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004695 "error_if_validators": (
4696 TosaErrorValidator.evWrongInputType,
4697 TosaErrorValidator.evWrongOutputType,
4698 TosaErrorValidator.evWrongInputList,
4699 TosaErrorValidator.evWrongOutputList,
4700 TosaErrorValidator.evWrongRank,
4701 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004702 "data_gen": {
4703 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4704 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004705 },
4706 "scatter": {
4707 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004708 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004709 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004710 "build_fcn": (
4711 build_scatter,
4712 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004713 TosaTensorValuesGen.tvgScatter,
4714 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004715 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004716 "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004717 "error_if_validators": (
4718 TosaErrorValidator.evWrongInputType,
4719 TosaErrorValidator.evWrongOutputType,
4720 TosaErrorValidator.evWrongInputList,
4721 TosaErrorValidator.evWrongOutputList,
4722 TosaErrorValidator.evWrongRank,
4723 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004724 "data_gen": {
4725 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4726 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004727 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004728 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004729 "resize": {
4730 "op": Op.RESIZE,
Tai Lyc5c2a7e2024-02-22 23:26:28 +00004731 "operands": (4, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004732 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004733 "build_fcn": (
4734 build_resize,
4735 TosaTensorGen.tgNHWC,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004736 TosaTensorValuesGen.tvgResize,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004737 TosaArgGen.agResize,
4738 ),
James Ward24dbc422022-10-19 12:20:31 +01004739 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004740 "invalid_test_validators": (
4741 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004742 ),
4743 "error_if_validators": (
4744 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004745 TosaErrorValidator.evScaleSmallerEqualZero,
4746 TosaErrorValidator.evScaleNLargerMax,
4747 TosaErrorValidator.evScaleDLargerMax,
4748 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004749 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004750 TosaErrorValidator.evBorderSmallerMin,
4751 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004752 TosaErrorValidator.evWrongInputType,
4753 TosaErrorValidator.evWrongOutputType,
4754 TosaErrorValidator.evWrongRank,
4755 TosaErrorValidator.evWrongInputList,
4756 TosaErrorValidator.evWrongOutputList,
4757 TosaErrorValidator.evBatchMismatch,
4758 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004759 TosaErrorValidator.evResizeOutputShapeMismatch,
4760 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004761 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004762 "data_gen": {
4763 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4764 },
4765 "compliance": {"relative": 0.006},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004766 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004767 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004768 "cast": {
4769 "op": Op.CAST,
4770 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004771 "build_fcn": (
4772 build_cast,
4773 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004774 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004775 TosaArgGen.agCast,
4776 ),
James Ward8b390432022-08-12 20:48:56 +01004777 "types": (
4778 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004779 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004780 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004781 DType.INT8,
4782 DType.INT16,
4783 DType.INT32,
4784 DType.BOOL,
Won Jeon2c34b462024-02-06 18:37:00 +00004785 DType.FP8E4M3,
4786 DType.FP8E5M2,
James Ward8b390432022-08-12 20:48:56 +01004787 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004788 "error_if_validators": (
4789 TosaErrorValidator.evWrongInputType,
4790 TosaErrorValidator.evWrongOutputType,
4791 TosaErrorValidator.evWrongInputList,
4792 TosaErrorValidator.evWrongOutputList,
4793 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004794 "data_gen": {
4795 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4796 },
4797 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004798 },
4799 "rescale": {
4800 "op": Op.RESCALE,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004801 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004802 "build_fcn": (
4803 build_rescale,
4804 TosaTensorGen.tgBasic,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004805 TosaTensorValuesGen.tvgRescale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004806 TosaArgGen.agRescale,
4807 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004808 "types": [
4809 DType.UINT8,
4810 DType.INT8,
4811 DType.INT16,
4812 DType.INT32,
4813 DType.INT48,
4814 DType.UINT16,
4815 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004816 "error_if_validators": (
4817 TosaErrorValidator.evInputZeroPointNotZero,
4818 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004819 TosaErrorValidator.evU16InputZeroPointNotValid,
4820 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004821 TosaErrorValidator.evScaleTrue,
4822 TosaErrorValidator.evScaleNotTrue,
4823 TosaErrorValidator.evWrongInputType,
4824 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004825 TosaErrorValidator.evWrongInputList,
4826 TosaErrorValidator.evWrongOutputList,
4827 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004828 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004829 # Custom
4830 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004831 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004832 # Two variants of cond_if, one that generates one of two constant tensors (no
4833 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4834 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004835 "cond_if_const": {
4836 "op": Op.COND_IF,
4837 "operands": (0, 2),
4838 "build_fcn": (
4839 build_cond_if_const,
4840 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004841 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004842 TosaArgGen.agCondIf,
4843 ),
4844 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004845 "error_if_validators": (
4846 TosaErrorValidator.evOutputListThenGraphMismatch,
4847 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004848 TosaErrorValidator.evCondIfCondNotMatchingBool,
4849 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004850 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004851 },
4852 "cond_if_binary": {
4853 "op": Op.COND_IF,
4854 "operands": (2, 0),
4855 "build_fcn": (
4856 build_cond_if_binary,
4857 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004858 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004859 TosaArgGen.agCondIf,
4860 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004861 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004862 "error_if_validators": (
4863 TosaErrorValidator.evInputListThenGraphMismatch,
4864 TosaErrorValidator.evInputListElseGraphMismatch,
4865 TosaErrorValidator.evOutputListThenGraphMismatch,
4866 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004867 TosaErrorValidator.evCondIfCondNotMatchingBool,
4868 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004869 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004870 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004871 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004872 "while_loop": {
4873 "op": Op.WHILE_LOOP,
4874 "operands": (0, 1),
4875 "build_fcn": (
4876 build_while_loop,
4877 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004878 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004879 TosaArgGen.agWhileLoop,
4880 ),
4881 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004882 "error_if_validators": (
4883 TosaErrorValidator.evInputListOutputListMismatch,
4884 TosaErrorValidator.evInputListCondGraphMismatch,
4885 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4886 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4887 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004888 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004889 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004890 },
Luke Hutton57287132023-02-06 14:54:18 +00004891 "fft2d": {
4892 "op": Op.FFT2D,
4893 "operands": (2, 0),
4894 "rank": (3, 3),
4895 "build_fcn": (
4896 build_fft2d,
4897 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004898 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004899 TosaArgGen.agFFT2d,
4900 ),
4901 "types": [DType.FP32],
4902 "error_if_validators": (
4903 TosaErrorValidator.evWrongInputType,
4904 TosaErrorValidator.evWrongOutputType,
4905 TosaErrorValidator.evWrongInputList,
4906 TosaErrorValidator.evWrongOutputList,
4907 TosaErrorValidator.evWrongRank,
4908 TosaErrorValidator.evBatchMismatch,
4909 TosaErrorValidator.evKernelNotPowerOfTwo,
4910 TosaErrorValidator.evFFTInputShapeMismatch,
4911 TosaErrorValidator.evFFTOutputShapeMismatch,
4912 ),
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004913 "data_gen": {
4914 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4915 },
Luke Hutton57287132023-02-06 14:54:18 +00004916 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004917 "rfft2d": {
4918 "op": Op.RFFT2D,
4919 "operands": (1, 0),
4920 "rank": (3, 3),
4921 "build_fcn": (
4922 build_rfft2d,
4923 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004924 TosaTensorValuesGen.tvgLazyGenDefault,
4925 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00004926 ),
4927 "types": [DType.FP32],
4928 "error_if_validators": (
4929 TosaErrorValidator.evWrongInputType,
4930 TosaErrorValidator.evWrongOutputType,
4931 TosaErrorValidator.evWrongInputList,
4932 TosaErrorValidator.evWrongOutputList,
4933 TosaErrorValidator.evWrongRank,
4934 TosaErrorValidator.evBatchMismatch,
4935 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004936 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004937 ),
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004938 "data_gen": {
4939 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4940 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004941 },
Won Jeon74342e52024-01-09 00:34:40 +00004942 # Shape
4943 "add_shape": {
4944 "op": Op.ADD_SHAPE,
4945 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004946 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004947 "build_fcn": (
4948 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004949 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004950 TosaTensorValuesGen.tvgAddSub,
4951 TosaArgGen.agNone,
4952 ),
4953 "types": [DType.SHAPE],
4954 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4955 },
4956 "sub_shape": {
4957 "op": Op.SUB_SHAPE,
4958 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004959 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004960 "build_fcn": (
4961 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004962 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004963 TosaTensorValuesGen.tvgAddSub,
4964 TosaArgGen.agNone,
4965 ),
4966 "types": [DType.SHAPE],
4967 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4968 },
4969 "mul_shape": {
4970 "op": Op.MUL_SHAPE,
4971 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004972 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004973 "build_fcn": (
4974 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004975 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004976 TosaTensorValuesGen.tvgMul,
4977 TosaArgGen.agNone,
4978 ),
4979 "types": [DType.SHAPE],
4980 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4981 },
4982 "div_shape": {
4983 "op": Op.DIV_SHAPE,
4984 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004985 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004986 "build_fcn": (
4987 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004988 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004989 TosaTensorValuesGen.tvgIntDiv,
4990 TosaArgGen.agNone,
4991 ),
4992 "types": [DType.SHAPE],
4993 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4994 },
4995 "concat_shape": {
4996 "op": Op.CONCAT_SHAPE,
4997 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004998 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004999 "build_fcn": (
5000 build_concat,
5001 TosaTensorGen.tgConcat,
5002 TosaTensorValuesGen.tvgConcat,
5003 TosaArgGen.agNone,
5004 ),
5005 "types": [DType.SHAPE],
5006 "error_if_validators": (),
5007 },
5008 "const_shape": {
5009 "op": Op.CONST_SHAPE,
5010 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005011 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005012 "build_fcn": (
5013 build_const,
5014 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00005015 TosaTensorValuesGen.tvgLazyGenDefault,
5016 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00005017 ),
5018 "types": [DType.SHAPE],
5019 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005020 }
5021
Kevin Cheng550ccc52021-03-03 11:21:43 -08005022
Eric Kunzee5e26762020-10-13 16:11:07 -07005023class OutputShaper:
5024 # Methods in this class compute the expected output shape and datatype
5025 # for common classes of operations
5026 def __init__(self):
5027 pass
5028
5029 # These methods return arguments that can be used for
5030 # creating a new output tensor
5031 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005032 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
5033 if error_name != ErrorIf.RankMismatch:
5034 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005035 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005036
5037 shape = []
5038 for i in range(len(a.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005039 if a.shape[i] == 1 and error_name is None:
Eric Kunzee5e26762020-10-13 16:11:07 -07005040 shape.append(b.shape[i])
5041 else:
5042 shape.append(a.shape[i])
5043
Jerry Ge135c9552023-05-23 20:59:32 +00005044 fuzz_idx = rng.integers(0, len(a.shape))
5045 if error_name == ErrorIf.DimensionMismatch:
5046 shape[fuzz_idx] += 1
5047
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005048 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005049 all_dtypes = [
5050 DType.INT8,
5051 DType.INT16,
5052 DType.INT32,
5053 DType.INT48,
James Ward24dbc422022-10-19 12:20:31 +01005054 DType.FP16,
5055 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005056 DType.FP32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005057 ]
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005058 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5059 outputDType = rng.choice(wrong_dtypes)
5060 else:
5061 outputDType = a.dtype
5062
5063 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005064
5065 @staticmethod
5066 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005067 assert len(a.shape) == len(b.shape)
5068 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005069
5070 shape = []
5071 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005072 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07005073 shape.append(a.shape[i])
5074
Kevin Cheng550ccc52021-03-03 11:21:43 -08005075 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005076
5077 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005078 def unaryOp(ser, rng, a, error_name=None):
5079 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005080 all_dtypes = [
5081 DType.INT8,
5082 DType.INT16,
5083 DType.INT32,
5084 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005085 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005086 DType.FP16,
5087 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005088 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005089 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5090 outputDType = rng.choice(wrong_dtypes)
5091 else:
5092 outputDType = a.dtype
5093
5094 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005095
5096 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005097 def selectOp(ser, rng, cond, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005098 if error_name != ErrorIf.RankMismatch:
5099 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005100 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005101
5102 shape = []
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005103 for i in range(len(cond.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005104 if cond.shape[i] == 1 and error_name is None:
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005105 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
5106 else:
5107 shape.append(cond.shape[i])
Eric Kunzee5e26762020-10-13 16:11:07 -07005108
Jerry Ge135c9552023-05-23 20:59:32 +00005109 fuzz_idx = rng.integers(0, len(a.shape))
5110 if error_name == ErrorIf.DimensionMismatch:
5111 shape[fuzz_idx] += 1
5112
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005113 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005114 all_dtypes = [
5115 DType.INT8,
5116 DType.INT16,
5117 DType.INT32,
5118 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005119 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005120 DType.FP16,
5121 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005122 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005123 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5124 outputDType = rng.choice(wrong_dtypes)
5125 else:
5126 outputDType = a.dtype
5127
5128 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005129
5130 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005131 def binaryComparisonOp(ser, rng, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005132 if error_name != ErrorIf.RankMismatch:
5133 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005134 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005135
5136 # Do broadcast
5137 shape = []
5138 for i in range(len(a.shape)):
Eric Kunzea1d49852022-01-04 10:07:29 -08005139 if a.shape[i] == 1 and len(b.shape) > i:
Eric Kunzee5e26762020-10-13 16:11:07 -07005140 shape.append(b.shape[i])
5141 else:
5142 shape.append(a.shape[i])
5143
Jerry Ge135c9552023-05-23 20:59:32 +00005144 fuzz_idx = rng.integers(0, len(a.shape))
5145 if error_name == ErrorIf.DimensionMismatch:
5146 shape[fuzz_idx] += 1
5147
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005148 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005149 wrong_dtypes = [
5150 DType.INT8,
5151 DType.INT16,
5152 DType.INT32,
5153 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005154 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005155 DType.FP16,
5156 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005157 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005158 outputDType = rng.choice(wrong_dtypes)
5159 else:
5160 outputDType = DType.BOOL
5161
5162 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005163
5164 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01005165 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005166 shape = a.shape.copy()
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005167 if error_name not in [
5168 ErrorIf.AxisSmallerZero,
5169 ErrorIf.AxisLargerRank,
5170 ErrorIf.ShapeOfAxisNotOne,
5171 ]:
Matthew Haddond6ce7252021-09-29 15:35:44 +01005172 shape[axis] = 1
5173 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
5174 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07005175
Matthew Haddond6ce7252021-09-29 15:35:44 +01005176 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005177 all_dtypes = [
5178 DType.INT8,
5179 DType.INT16,
5180 DType.INT32,
5181 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005182 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005183 DType.FP16,
5184 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005185 ]
Matthew Haddond6ce7252021-09-29 15:35:44 +01005186 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5187 outputDType = rng.choice(wrong_dtypes)
5188 else:
5189 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005190
Matthew Haddond6ce7252021-09-29 15:35:44 +01005191 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005192
5193 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005194 def argmaxOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005195 shape = a.shape.copy()
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005196
5197 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
5198 del shape[axis]
5199
5200 if error_name == ErrorIf.ArgmaxOutputRankMismatch:
5201 remove = rng.choice([True, False])
5202 if remove and len(shape) > 1:
5203 del shape[0]
5204 else:
5205 shape.append(1)
5206 elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
5207 for i in range(len(shape)):
5208 shape[i] = shape[i] + rng.integers(1, 10)
5209
5210 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005211 all_dtypes = [
5212 DType.INT8,
5213 DType.INT16,
5214 DType.INT32,
5215 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005216 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005217 DType.FP16,
5218 DType.BF16,
Won Jeon2c34b462024-02-06 18:37:00 +00005219 DType.FP8E4M3,
5220 DType.FP8E5M2,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005221 ]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005222 wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
5223 outputDType = rng.choice(wrong_dtypes)
5224 else:
5225 outputDType = DType.INT32
5226
5227 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005228
5229 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01005230 def conv2dOp(
5231 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
5232 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07005233
5234 # IFM: NHWC
5235 # Filter: OHWI
5236 # OFM: NHWC
5237
Kevin Cheng550ccc52021-03-03 11:21:43 -08005238 h = (
5239 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005240 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08005241 + padding[0]
5242 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005243 - (filter.shape[1] - 1) * dilations[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08005244 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005245
Kevin Cheng550ccc52021-03-03 11:21:43 -08005246 w = (
5247 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005248 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08005249 + padding[2]
5250 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005251 - (filter.shape[2] - 1) * dilations[1]
Kevin Cheng550ccc52021-03-03 11:21:43 -08005252 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005253
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005254 if error_name == ErrorIf.ConvOutputShapeMismatch:
5255 choices = [1, 2, 3]
5256 change = rng.choice(choices)
5257 # increment in multiples of stride to not hit non-integer error case
5258 if change in [1, 3]:
5259 h = h + (rng.choice(choices) * strides[0])
5260 if change in [2, 3]:
5261 w = w + (rng.choice(choices) * strides[1])
Les Bell0e027d42021-11-09 14:42:14 +00005262
Eric Kunzee5e26762020-10-13 16:11:07 -07005263 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
5264
James Ward8b390432022-08-12 20:48:56 +01005265 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00005266 # Pick some potentially correct output dtype if input type is incorrect
5267 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07005268 else:
James Ward8b390432022-08-12 20:48:56 +01005269 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00005270
5271 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01005272 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005273 excludes = [DType.FP16, DType.FP32]
Jeremy Johnson80fd9b82024-03-12 11:46:50 +00005274 elif ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
Won Jeon2c34b462024-02-06 18:37:00 +00005275 excludes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01005276 else:
5277 excludes = [out_dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +01005278 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00005279 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07005280
Kevin Cheng550ccc52021-03-03 11:21:43 -08005281 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005282
5283 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01005284 def conv3dOp(
5285 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
5286 ):
Kevin Cheng1533b852021-09-01 12:51:58 -07005287
5288 # IFM: NDHWC
5289 # Filter: ODHWI
5290 # OFM: NDHWC
5291
5292 d = (
5293 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005294 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07005295 + padding[0]
5296 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005297 - (filter.shape[1] - 1) * dilations[0]
Kevin Cheng1533b852021-09-01 12:51:58 -07005298 ) // strides[0] + 1
5299
5300 h = (
5301 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005302 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07005303 + padding[2]
5304 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005305 - (filter.shape[2] - 1) * dilations[1]
Kevin Cheng1533b852021-09-01 12:51:58 -07005306 ) // strides[1] + 1
5307
5308 w = (
5309 ifm.shape[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005310 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07005311 + padding[4]
5312 + padding[5]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005313 - (filter.shape[3] - 1) * dilations[2]
Kevin Cheng1533b852021-09-01 12:51:58 -07005314 ) // strides[2] + 1
5315
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005316 if error_name == ErrorIf.ConvOutputShapeMismatch:
5317 choices = [1, 2, 3, 4]
5318 change = rng.choice(choices)
5319 # increment in multiples of stride to not hit non-integer error case
5320 if change in [1, 4]:
5321 d = d + (rng.choice(choices) * strides[0])
5322 if change in [2, 4]:
5323 h = h + (rng.choice(choices) * strides[1])
5324 if change in [3, 4]:
5325 w = w + (rng.choice(choices) * strides[2])
Les Bell0e027d42021-11-09 14:42:14 +00005326
Kevin Cheng1533b852021-09-01 12:51:58 -07005327 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
5328
James Ward8b390432022-08-12 20:48:56 +01005329 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00005330 # Pick some potentially correct output dtype if input type is incorrect
5331 out_dtype = DType.INT32
Kevin Cheng1533b852021-09-01 12:51:58 -07005332 else:
James Ward8b390432022-08-12 20:48:56 +01005333 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00005334
5335 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01005336 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005337 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01005338 else:
5339 excludes = [out_dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +01005340 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00005341 out_dtype = rng.choice(wrong_dtypes)
Kevin Cheng1533b852021-09-01 12:51:58 -07005342
5343 return ser.addOutput(ofm_shape, out_dtype)
5344
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for a depthwise 2D convolution.

    Layouts (as indexed below): ifm is NHWC, filter is HWCM and the output
    is [N, OH, OW, C * M].

    Args:
        ser: serializer whose addOutput(shape, dtype) creates the tensor.
        rng: random generator, used only for ERROR_IF perturbations.
        ifm: input feature-map tensor.
        filter: weight tensor.
        accum_dtype: output dtype for the non-error case.
        strides: (stride_h, stride_w).
        padding: (h_before, h_after, w_before, w_after).
        dilations: (dilation_h, dilation_w).
        error_name: optional ErrorIf case to provoke.

    Returns:
        The result of ser.addOutput for the computed shape and dtype.
    """

    def _extent(in_size, pad_before, pad_after, kernel, dilation, stride):
        # Output size of a dilated convolution along a single spatial axis.
        span = (kernel - 1) * dilation
        return (in_size - 1 + pad_before + pad_after - span) // stride + 1

    out_h = _extent(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    out_w = _extent(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        options = [1, 2, 3]
        pick = rng.choice(options)
        # Grow by whole strides so the non-integer-output ERROR_IF is not hit.
        if pick in (1, 3):
            out_h += rng.choice(options) * strides[0]
        if pick in (2, 3):
            out_w += rng.choice(options) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[2] * filter.shape[3]]

    # With a bad input type, fall back to INT32 as a plausible output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 from the "wrong" candidates
        # (presumably because either is an acceptable output - confirm).
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005395
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor for a 2D pooling op on an NHWC input.

    Args:
        ser: serializer whose addOutput(shape, dtype) creates the tensor.
        rng: random generator, used only for ERROR_IF perturbations.
        ifm: input tensor (indexed as N, H, W, C).
        kernel: (kernel_h, kernel_w).
        stride: (stride_h, stride_w).
        pad: (h_before, h_after, w_before, w_after).
        error_name: optional ErrorIf case to provoke.

    Returns:
        The result of ser.addOutput. The dtype is ifm.dtype unless
        error_name is ErrorIf.WrongOutputType.
    """
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # Bad stride/pad means the test is an ERROR_IF case anyway; use 1x1.
        out_h = out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        options = [1, 2, 3]
        pick = rng.choice(options)
        # Grow by whole strides so the non-integer-output ERROR_IF is not hit.
        if pick in (1, 3):
            out_h += rng.choice(options) * stride[0]
        if pick in (2, 3):
            out_w += rng.choice(options) * stride[1]
    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    wrong_dtypes = list(set(candidates) - set([ifm.dtype]))
    return ser.addOutput(ofm_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005435
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Create the output tensor for FULLY_CONNECTED.

    input is indexed as (N, IC) and filter as (OC, IC). The output shape is
    [N, OC] and its dtype is always accum_dtype. The original note says that
    dtype is validated (and invalidated for ERROR_IF tests) by the argument
    generator. rng and error_name are not used here.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005448
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the output tensor for a batched MATMUL.

    a is indexed as (N, H, C) and b as (N, C, W). The output shape is
    [N, H, W], built from a.shape[0], a.shape[1] and b.shape[2].

    Args:
        ser: serializer whose addOutput(shape, dtype) creates the tensor.
        rng: random generator used to pick a wrong dtype for ERROR_IF tests.
        a: left-hand input tensor.
        b: right-hand input tensor.
        accum_dtype: output dtype for the non-error case (validated in the
            argument generator).
        error_name: optional ErrorIf case to provoke.

    Returns:
        The result of ser.addOutput for the computed shape and dtype.

    Raises:
        ValueError: for ErrorIf.WrongOutputType when a.dtype has no list of
            incorrect output dtypes below.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        # Each tuple lists output dtypes that are invalid for a.dtype.
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype in (DType.FP8E4M3, DType.FP8E5M2):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        else:
            # Fail loudly instead of leaving incorrect_types unbound.
            raise ValueError(
                f"matmulOp: no incorrect output dtypes defined for input {a.dtype}"
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005514
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Create the output tensor for CONCAT of ``inputs`` along ``axis``.

    The output shape starts as a copy of inputs[0].shape. The ``axis`` extent
    then accumulates each remaining input's extent, except for the rank
    mismatch and invalid-axis ERROR_IF cases, where the first input's shape
    is kept.

    Returns:
        The result of ser.addOutput. The dtype is inputs[0].dtype unless
        error_name is ErrorIf.WrongOutputType.
    """
    head = inputs[0]
    tail = inputs[1:]

    output_shape = head.shape.copy()
    can_sum = error_name not in (
        ErrorIf.ConcatInputRankMismatch,  # tensors of different ranks
        ErrorIf.AxisLargerRank,  # axis out of range
        ErrorIf.AxisSmallerZero,
    )
    if can_sum:
        for other in tail:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    out_dtype = head.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - set([head.dtype])))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005550
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor for PAD.

    Output dim d is padding[d][0] + padding[d][1] + a.shape[d].

    ERROR_IF handling:
      * PadOutputShapeMismatch: one random dim is grown by 1 or 2.
      * RankMismatch: the shape is replaced by
        gtu.get_rank_mismatch_shape(rng, shape).
      * WrongOutputType: a dtype other than a.dtype is chosen.
    """
    output_shape = a.shape.copy()
    for dim, extent in enumerate(a.shape):
        before, after = padding[dim][0], padding[dim][1]
        output_shape[dim] = before + after + extent

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(output_shape)))
        output_shape[bad_dim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))
    else:
        out_dtype = a.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005581
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for DIM: shape [1], dtype DType.SHAPE.

    a and axis are not used here. For ErrorIf.WrongOutputType a random
    non-SHAPE dtype is used instead.
    """
    out_dtype = DType.SHAPE
    if error_name == ErrorIf.WrongOutputType:
        # None of these is DType.SHAPE, so nothing needs excluding.
        non_shape = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(non_shape)))
    return ser.addOutput([1], out_dtype)
5602
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor for RESHAPE to ``shape``.

    ``shape`` is copied, never modified. For
    ErrorIf.TensorSizeInputOutputMismatch every dim of the copy is grown by
    rng.integers(1, 10). For ErrorIf.WrongOutputType a dtype other than
    a.dtype is chosen.
    """
    new_shape = shape.copy()
    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, extent in enumerate(new_shape):
            new_shape[idx] = extent + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(new_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - set([a.dtype]))
    return ser.addOutput(new_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005627
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Create the output tensor for SLICE; the output shape is a copy of ``size``.

    ``start`` is not used here. ERROR_IF handling:
      * WrongOutputType: a dtype other than input.dtype is chosen.
      * SizeOutputShapeMismatch: every dim is perturbed by a small delta.
      * InputSizeStartLengthMismatch: the input's shape is used instead.
      * RankMismatch: the shape is replaced by
        gtu.get_rank_mismatch_shape(rng, shape).
    """
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([input.dtype])))
    else:
        out_dtype = input.dtype

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, extent in enumerate(out_shape):
            # Extents of 2 or less are only ever increased.
            deltas = [1, 2] if extent <= 2 else [-2, -1, 1, 2]
            out_shape[idx] = extent + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005661
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor for TILE.

    Output dim d is a.shape[d] * multiples[d]. For ErrorIf.RankMismatch the
    shape is replaced by gtu.get_rank_mismatch_shape(rng, shape). For
    ErrorIf.WrongOutputType a dtype other than a.dtype is chosen.
    """
    tiled = a.shape.copy()
    assert len(multiples) == len(tiled)
    for dim, extent in enumerate(a.shape):
        tiled[dim] = extent * multiples[dim]

    if error_name == ErrorIf.RankMismatch:
        tiled = gtu.get_rank_mismatch_shape(rng, tiled)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(tiled, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005690
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor for TRANSPOSE.

    Output dim d is a.shape[perms[d]]. For ErrorIf.IndexOutsideBounds and
    ErrorIf.IndexUsedTwice, perms is not applied and the input shape is
    kept. Further ERROR_IF handling:
      * TensorSizeInputOutputMismatch: every dim is grown by
        rng.integers(1, 10).
      * RankMismatch: the shape is replaced by
        gtu.get_rank_mismatch_shape(rng, shape).
      * WrongOutputType: a dtype other than a.dtype is chosen.
    """
    permuted = a.shape.copy()
    assert len(perms) == len(permuted)

    if error_name not in (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice):
        for dst in range(len(permuted)):
            permuted[dst] = a.shape[perms[dst]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dst in range(len(permuted)):
            permuted[dst] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        permuted = gtu.get_rank_mismatch_shape(rng, permuted)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(permuted, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - set([a.dtype]))
    return ser.addOutput(permuted, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005723
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor for GATHER.

    values must be rank 3 (not checked for ErrorIf.WrongRank) and indices
    rank 2, with matching leading dims. The output shape is
    [values.shape[0], indices.shape[1], values.shape[2]]. For
    ErrorIf.WrongOutputType a dtype other than values.dtype is chosen.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    out_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([values.dtype])))
    else:
        out_dtype = values.dtype

    return ser.addOutput(out_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005749
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor for SCATTER.

    The output uses values_in.shape as-is (the same object, not a copy).
    Shape consistency between values_in, indices and input is asserted; the
    rank check on values_in is skipped for ErrorIf.WrongRank. For
    ErrorIf.WrongOutputType a dtype other than values_in.dtype is chosen.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(values_in.shape, values_in.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    wrong_dtypes = list(set(candidates) - set([values_in.dtype]))
    return ser.addOutput(values_in.shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005780
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor for TABLE: same shape as ``input``.

    The dtype is INT32 for an INT16 input and INT8 otherwise. The input
    dtype is asserted to be INT8/INT16 unless error_name is
    ErrorIf.WrongInputType. For ErrorIf.WrongOutputType a different dtype is
    picked at random.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype == DType.INT16 or input.dtype == DType.INT8

    out_dtype = DType.INT8
    if input.dtype == DType.INT16:
        out_dtype = DType.INT32

    if error_name == ErrorIf.WrongOutputType:
        expected = out_dtype
        out_dtype = rng.choice(
            [
                dt
                for dt in (
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                    DType.BF16,
                )
                if dt != expected
            ]
        )

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005800
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor for RESIZE.

    scale is (y_n, y_d, x_n, x_d) and offset/border are (y, x) pairs. With
    IH/IW taken from input.shape[1]/input.shape[2]:

        OH = ((IH - 1) * y_n - offset[0] + border[0]) // y_d + 1
        OW = ((IW - 1) * x_n - offset[1] + border[1]) // x_d + 1

    mode and input_dtype are not used here. For any ERROR_IF test OH/OW are
    clamped to at least 1 and, except for ErrorIf.MaxDimExceeded, to at most
    gtu.MAX_RESIZE_DIMENSION - 1.

    Returns:
        The result of serializer.addOutput with dtype output_dtype.
    """
    y_n, y_d, x_n, x_d = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Keep the arithmetic below well defined for deliberately bad scales.
        y_n, y_d, x_n, x_d = (max(v, 1) for v in (y_n, y_d, x_n, x_d))

    out_h = ((input.shape[1] - 1) * y_n - offset[0] + border[0]) // y_d + 1
    out_w = ((input.shape[2] - 1) * x_n - offset[1] + border[1]) // x_d + 1

    if error_name is not None:
        # Altered scale/offset/border values can produce an unusable extent.
        out_h, out_w = max(out_h, 1), max(out_w, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            cap = gtu.MAX_RESIZE_DIMENSION - 1
            out_h, out_w = min(out_h, cap), min(out_w, cap)

    def _shift(extent, step):
        # Move by a whole denominator so the non-integer ERROR_IF is not hit;
        # step down instead of up when growing would reach the max dimension.
        if extent + step >= gtu.MAX_RESIZE_DIMENSION:
            extent -= step
            assert extent > 0  # Should have been caught in agResize
            return extent
        return extent + step

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        pick = rng.choice([1, 2, 3])
        if pick in (1, 3):
            out_h = _shift(out_h, y_d)
        if pick in (2, 3):
            out_w = _shift(out_w, x_d)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): the last dim uses input.shape[0], not input.shape[3] -
        # presumably because a wrong-rank input may lack a 4th dim; confirm.
        output_dims = [batch, out_h, out_w, input.shape[0]]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), out_h, out_w, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, out_h, out_w, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, out_h, out_w, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005879
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create an output tensor shaped like ``val`` with dtype ``out_dtype``.

    rng and error_name are not used here.
    """
    shape = val.shape
    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005883
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Create the output tensor for TRANSPOSE_CONV2D.

    ``output_shape`` comes from the caller. NOTE: for
    ErrorIf.ConvOutputShapeMismatch its elements 1 and/or 2 are increased
    *in place*, so the caller's list is modified as well.

    The dtype is accum_dtype, INT32 for ErrorIf.WrongInputType, or a random
    dtype from gtu.usableDTypes outside the excluded set for
    ErrorIf.WrongOutputType.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        options = [1, 2, 3]
        pick = rng.choice(options)
        if pick in (1, 3):
            output_shape[1] += rng.choice(options)
        if pick in (2, 3):
            output_shape[2] += rng.choice(options)

    # With a bad input type, fall back to INT32 as a plausible output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005909
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Create the two output tensors for FFT2D.

    Both inputs must share a dtype. Their shapes must match unless
    error_name is ErrorIf.FFTInputShapeMismatch, and the shape must be
    rank 3 unless it is ErrorIf.WrongRank. Both outputs get a copy of
    ifm1's shape and ifm1's dtype, modified for the ERROR_IF cases
    WrongOutputType, BatchMismatch and FFTOutputShapeMismatch.

    Returns:
        A list of the two tensors created by serializer.addOutput
        (presumably the real and imaginary parts - confirm).
    """
    assert ifm1.dtype == ifm2.dtype
    in_dtype = ifm1.dtype

    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    in_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = in_shape.copy()
    out_dtype = in_dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        bad_dim = rng.choice([1, 2])
        out_shape[bad_dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
5940
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Create the two output tensors for RFFT2D.

    Leading dims of value's shape are kept and the last dim W becomes
    W // 2 + 1. The shape must be rank 3 unless error_name is
    ErrorIf.WrongRank. Both outputs share the same shape and dtype
    (value.dtype unless error_name is ErrorIf.WrongOutputType).

    Returns:
        A list of the two tensors created by serializer.addOutput.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]

    out_dtype = value.dtype
    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        bad_dim = rng.choice([1, 2])
        out_shape[bad_dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
Won Jeon74342e52024-01-09 00:34:40 +00005965
@staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for ADD_SHAPE.

    The output shape is a copy of a.shape and the dtype is DType.SHAPE.
    For ErrorIf.DimensionMismatch one random element is increased by 1.
    For ErrorIf.WrongOutputType a dtype from gtu.get_wrong_output_type is
    used.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = list(a.shape)

    # Drawn for every error_name, even when it goes unused.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(gtu.get_wrong_output_type(a.dtype))
    else:
        out_dtype = DType.SHAPE
    return ser.addOutput(shape, out_dtype)