blob: 4309024c8db20743e5dec7ee9427a3186b36fc68 [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Header written at the top of every auto-generated C++ meta data file
# produced by TosaTestGen.serialize. The copyright year is evaluated once,
# when this module is imported.
TOSA_AUTOGENERATED_HEADER = f"""// Copyright (c) {datetime.today().year}, ARM Limited
// SPDX-License-Identifier: Apache-2.0
// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests
"""
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
# Maximum rank of tensor supported by test generator.
# This currently matches the 8K level defined in the specification.
TOSA_TENSOR_MAX_RANK = 6
# Further limits named after the 8K level.
# NOTE(review): how each one is applied is decided at the point of use
# (not visible in this section) - confirm against the specification.
TOSA_8K_LEVEL_MAX_SCALE = 64
TOSA_8K_LEVEL_MAX_KERNEL = 8192
TOSA_8K_LEVEL_MAX_STRIDE = 8192

# Main compliance dot product statistical test range
# NOTE(review): the exact meaning of TEST_SETS / MIN is defined where they
# are used (not visible in this section).
TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
TOSA_MI_DOT_PRODUCT_MIN = 1000
42
def __init__(self, args):
    """Set up the generator from the parsed command line ``args``.

    Keeps ``args`` and reads output_dir, random_seed, generate_lib_path and
    tensor_fp_value_range from it. Also builds the op lists, the quantization
    generator, the desc.json schema validator and the data generator library,
    and pre-computes the random value range for each floating point type.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    # Created per test by createSerializer()
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # Force makeShape to do a specific starting shape
    self.targetted_shape = None
    # JSON schema validation
    self.descSchemaValidator = TestDescSchemaValidator()
    # Data generator library is sometimes needed for compliance set up
    # even if we are generating the data later (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def clip_bound(bound, fp_max):
        # "max"/"-max" select the data type's largest magnitude; any other
        # non-zero bound is clipped into [-fp_max, fp_max]
        if bound == "max":
            return fp_max
        if bound == "-max":
            return -fp_max
        if bound < 0:
            return max(bound, -fp_max)
        if bound > 0:
            return min(bound, fp_max)
        return bound

    def convert_range(bounds, fp_max):
        # Program argument bounds -> sorted (low, high) tuple for this type
        return tuple(sorted(clip_bound(bound, fp_max) for bound in bounds))

    # Work out floating point range per floating point data type
    self.random_float_range = {
        fp_dtype: convert_range(
            args.tensor_fp_value_range,
            TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[fp_dtype],
        )
        for fp_dtype in (
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        )
    }
83 )
84
def createSerializer(self, opName, testPath):
    """Create a fresh serializer for a test in ``<basePath>/<opName>/<testPath>``.

    Creates the directory if needed and records ``opName/testPath`` in
    ``self.testPath``. Constants are embedded in the flatbuffer unless lazy
    data generation (constants go to input files) or dump_consts is enabled.
    """
    self.testPath = os.path.join(opName, testPath)

    full_path = os.path.join(self.basePath, self.testPath)
    os.makedirs(full_path, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        const_mode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        const_mode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        const_mode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(full_path, const_mode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
def getSerializer(self):
    """Return the current test serializer (None until createSerializer runs)."""
    serializer = self.ser
    return serializer
101
def serialize(self, testName, metaData=None):
    """Write the current test into its test directory.

    Writes ``<testName>.tosa`` (the flatbuffer), then validates and writes
    ``desc.json``. When ``metaData`` has the relevant entries, it also writes
    C++ source files that embed the data generation / compliance meta data
    as JSON raw string literals.

    Args:
        testName: Base name for the ``.tosa`` and ``.cpp`` output files.
        metaData: Optional dict; when truthy it is stored as ``desc["meta"]``.
            Its ``data_gen`` entry is written out as C++ only for lazy data
            generation; its ``compliance`` entry is written whenever present.
    """
    path = Path(self.basePath) / self.testPath

    def write_cpp_meta_data(suffix, description, var_prefix, data):
        # Embed `data` as JSON inside a C++ raw string literal named
        # <var_prefix>_<test directory name>
        path_md = path / f"{testName}_{suffix}.cpp"
        with path_md.open("w") as fd:
            fd.write(TOSA_AUTOGENERATED_HEADER)
            fd.write(f"// {description}\n\n")
            fd.write(f'const char* {var_prefix}_{path.stem} = R"(')
            json.dump(data, fd)
            fd.write(')";\n\n')

    # Write out TOSA flatbuffer binary
    path_fb = path / f"{testName}.tosa"
    with path_fb.open("wb") as fd:
        fd.write(self.ser.serialize())

    # Get JSON descriptor from serializer
    desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))

    if metaData:
        # Add extra meta data to desc.json
        desc["meta"] = metaData

    # Validate desc.json before we output it
    self.descSchemaValidator.validate_config(desc)

    if metaData:
        if "data_gen" in metaData and self.args.lazy_data_gen:
            # Output data generation meta data as CPP data
            write_cpp_meta_data(
                "meta_data_gen",
                "Test meta data for data generation setup",
                "json_tdg_config",
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            # Output compliance meta data as CPP data
            write_cpp_meta_data(
                "meta_compliance",
                "Test meta data for compliance validation",
                "json_tvf_config",
                metaData["compliance"],
            )

    # Write desc.json
    path_desc = path / "desc.json"
    with path_desc.open("w") as fd:
        json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a new numpy Generator.

    Args:
        seed: Seed for the new generator; None means ``self.random_seed + 1``.
    """
    self.rng = np.random.default_rng(self.random_seed + 1 if seed is None else seed)
150
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the (low, high) value range used for random data of ``dtype``.

    For integer-like types (including BOOL and SHAPE) ``high`` is exclusive,
    unless ``high_inclusive`` is True, in which case ``high - 1`` is returned.
    Floating point types return the range pre-computed in
    ``self.random_float_range``; ``high_inclusive`` is not applied to them.

    Raises:
        Exception: if ``dtype`` is not a supported type.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
        return self.random_float_range[dtype]

    if dtype == DType.SHAPE:
        # Shape values come from the tensor_shape_range program argument
        bounds = tuple(self.args.tensor_shape_range[0:2])
    else:
        # Exclusive-high ranges for the fixed-width types
        fixed_ranges = (
            (DType.BOOL, (0, 2)),
            (DType.UINT8, (0, 256)),
            (DType.UINT16, (0, 65536)),
            # TOSA specific INT4 weight range from -7 to 7
            (DType.INT4, (-7, 8)),
            (DType.INT8, (-128, 128)),
            (DType.INT16, (-32768, 32768)),
            (DType.INT32, (-(1 << 31), (1 << 31))),
            (DType.INT48, (-(1 << 47), (1 << 47))),
        )
        for candidate, candidate_bounds in fixed_ranges:
            if dtype == candidate:
                bounds = candidate_bounds
                break
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

    if high_inclusive:
        # Inclusive range: low <= value <= high
        return (bounds[0], bounds[1] - 1)
    # Exclusive high: low <= value < high
    return bounds
185
def getRandTensor(self, shape, dtype, data_range=None):
    """Return a random numpy tensor of ``shape`` holding ``dtype`` values.

    Values are drawn from ``data_range`` (low, high) when given, otherwise
    from getDTypeRange(dtype). BOOL picks False/True and ignores the range.
    BF16 and FP8 values are converted from an FP32 draw and returned as
    float32 numpy data. Integer types default to int32 storage.
    """
    low, high = self.getDTypeRange(dtype) if data_range is None else data_range

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))

    if dtype in (
        DType.FP16,
        DType.BF16,
        DType.FP32,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ):
        samples = self.rng.uniform(low=low, high=high, size=shape)
        if dtype == DType.FP16:
            return np.float16(samples)
        samples_f32 = np.float32(samples)
        if dtype == DType.BF16:
            # Floor the last 16 bits of each f32 value
            return np.float32(gtu.vect_f32_to_bf16(samples_f32))
        if dtype == DType.FP8E4M3:
            return np.float32(gtu.vect_f32_to_fp8e4m3(samples_f32))
        if dtype == DType.FP8E5M2:
            return np.float32(gtu.vect_f32_to_fp8e5m2(samples_f32))
        return samples_f32

    # Integer types: choose the numpy container, int32 for all others
    if dtype in (DType.INT4, DType.INT8):
        int_type = np.int8
    elif dtype == DType.UINT8:
        int_type = np.uint8
    elif dtype == DType.INT16:
        int_type = np.int16
    elif dtype == DType.UINT16:
        int_type = np.uint16
    elif dtype in (DType.INT48, DType.SHAPE):
        int_type = np.int64
    else:
        int_type = np.int32
    return int_type(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700231
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder per (shape, dtype) pair to the serializer.

    Random data is attached unless lazy data generation is enabled, in which
    case the data argument is None.

    Returns:
        The addPlaceholder results, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    tensors = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        tensors.append(self.ser.addPlaceholder(shape, dtype, data))
    return tensors
244
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant per (shape, dtype) pair to the serializer.

    Random data is attached unless lazy data generation is enabled, in which
    case the data argument is None.

    Returns:
        The addConst results, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    def make_data(shape, dtype):
        # Nothing up front when the data will be generated lazily later
        if self.args.lazy_data_gen:
            return None
        return self.getRandTensor(shape, dtype)

    return [
        self.ser.addConst(shape, dtype, make_data(shape, dtype))
        for shape, dtype in zip(shape_list, dtype_list)
    ]
257
def makeShape(self, rank):
    """Return a tensor shape as an int32 numpy array.

    If a target shape has been set (see setTargetShape), that shape is
    returned whatever ``rank`` is. Otherwise ``rank`` dimensions are drawn
    from ``[tensor_shape_range[0], tensor_shape_range[1])``.
    """
    # Compare with None explicitly: a numpy array target shape has no single
    # truth value (truthiness raises ValueError), and an empty rank-0 target
    # is still a target.
    if self.targetted_shape is not None:
        return np.int32(self.targetted_shape)
    return np.int32(
        self.rng.integers(
            low=self.args.tensor_shape_range[0],
            high=self.args.tensor_shape_range[1],
            size=rank,
        )
    )
Eric Kunzee5e26762020-10-13 16:11:07 -0700268
def setTargetShape(self, shape):
    """Make makeShape return ``shape``; set None to go back to random shapes."""
    self.targetted_shape = shape
271
def randInt(self, low=0, high=256):
    """Return one random np.int32 drawn from ``[low, high)`` using ``self.rng``."""
    draws = self.rng.integers(low=low, high=high, size=1)
    return np.int32(draws)[0]
274
def getRandNumberDType(self, dtype):
    """Return a single random value of ``dtype`` drawn from getDTypeRange(dtype).

    FP32/FP16 return numpy float scalars. BF16/FP8 values are produced from
    an FP32 draw by the matching gtu converter. BOOL picks False/True.
    INT48/SHAPE return np.int64, and every other type returns np.int32.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype == DType.FP32:
        return np.float32(self.rng.uniform(low=low, high=high))
    if dtype == DType.FP16:
        return np.float16(self.rng.uniform(low=low, high=high))
    if dtype in (DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
        as_f32 = np.float32(self.rng.uniform(low=low, high=high))
        if dtype == DType.BF16:
            return gtu.vect_f32_to_bf16(as_f32)
        if dtype == DType.FP8E4M3:
            return gtu.vect_f32_to_fp8e4m3(as_f32)
        return gtu.vect_f32_to_fp8e5m2(as_f32)
    if dtype == DType.BOOL:
        return self.rng.choice([False, True])

    draw = self.rng.integers(low, high, size=1)
    # Special size: INT48 and SHAPE values are held in 64 bits
    if dtype in (DType.INT48, DType.SHAPE):
        return np.int64(draw)[0]
    return np.int32(draw)[0]
298
def shapeStr(self, shape):
    """Return ``shape`` as an "x"-separated string, e.g. [1, 2, 3] -> "1x2x3"."""
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700307
def typeStr(self, dtype):
    """Return the short string name for ``dtype``.

    A list/tuple of (at least two) types is named by its first two entries
    joined with "x"; the third entry, the accumulator, is not included in
    the name.

    Raises:
        Exception: if a dtype has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    if not isinstance(dtype, (list, tuple)):
        if dtype not in gtu.DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return gtu.DTYPE_ATTRIBUTES[dtype]["str"]

    assert len(dtype) >= 2
    # Every entry is converted (so unknown types still raise), but only the
    # first two appear in the name
    names = [self.typeStr(entry) for entry in dtype]
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700321
def constrictBatchSize(self, shape):
    """Clamp ``shape[0]`` to ``args.max_batch_size`` in place and return ``shape``.

    Nothing changes when no max batch size is set, or when explicit target
    shapes were requested.
    """
    args = self.args
    if not args.max_batch_size or args.target_shapes:
        return shape
    shape[0] = min(shape[0], args.max_batch_size)
    return shape
327
def makeDimension(self):
    """Return a random dimension size via randInt over the tensor_shape_range bounds."""
    shape_range = self.args.tensor_shape_range
    return self.randInt(low=shape_range[0], high=shape_range[1])
332
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Build the compliance meta data dict for the expected output tensor.

    Returns None for ERROR_IF tests, for output types not supported by
    compliance, and for the dot product style ops listed below when their
    input type is not supported. Otherwise returns a dict with "mode" (a
    gtu.ComplianceMode name), "data_type" and, for some modes, an extra
    "*_info" entry. The mode is chosen, in priority order, from
    ``argsDict["dg_type"]``, the ``compliance`` settings of ``op`` and
    finally the op type, defaulting to EXACT.
    """
    # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
    UNSUPPORTED_NON_FP32_INPUT_OPS = (
        Op.MATMUL,
        Op.CONV2D,
        Op.FULLY_CONNECTED,
        Op.DEPTHWISE_CONV2D,
        Op.TRANSPOSE_CONV2D,
        Op.CONV3D,
    )
    # No compliance for error tests or unsupported types currently
    if errorName:
        return None
    if not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
        return None
    if (
        not gtu.dtypeIsSupportedByCompliance(inputType)
        and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
    ):
        return None

    def op_compliance_has(key):
        # True when the op definition carries the compliance setting `key`
        return "compliance" in op and key in op["compliance"]

    # "mode" is filled in last but inserted first to keep it first in the
    # dict's key order
    compliance_tens = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }
    dg_type = argsDict["dg_type"]

    if dg_type == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        compliance_tens["dot_product_info"] = {
            "s": argsDict["s"],
            # Prefer "ksb" when the op provides it
            "ks": int(argsDict["ksb"] if "ksb" in argsDict else argsDict["ks"]),
        }
    elif dg_type == gtu.DataGenType.SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif op_compliance_has("ulp"):
        mode = gtu.ComplianceMode.ULP
        compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
    elif op_compliance_has("relative"):
        mode = gtu.ComplianceMode.RELATIVE
        compliance_tens["relative_info"] = {
            "max": argsDict["max_abs_value"],
            "scale": op["compliance"]["relative"],
        }
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
        compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = gtu.ComplianceMode.ABS_ERROR
        if op_compliance_has("abs_error_lower_bound"):
            compliance_tens["abs_error_info"] = {
                "lower_bound": op["compliance"]["abs_error_lower_bound"]
            }
    elif op["op"] in (Op.SIN, Op.COS):
        mode = gtu.ComplianceMode.ABS_ERROR
        if op_compliance_has("abs_error_normal_divisor"):
            compliance_tens["abs_error_info"] = {
                "normal_divisor": op["compliance"]["abs_error_normal_divisor"]
            }
    else:
        mode = gtu.ComplianceMode.EXACT

    compliance_tens["mode"] = gtu.ComplianceMode(mode).name
    return compliance_tens
401
402 # Build Op functions
403 # Create the output tensor (calling OutputShaper as needed)
404 # Do final tweaks to attributes (if necessary for errorIf)
405 # Add Op into graph
406 # Return resulting tensor information or BuildInfo
407
class BuildInfo:
    """Enhanced build information containing result tensor and associated compliance dict."""

    def __init__(self, resultTensor, complianceDict):
        """Store the result tensor(s) and their compliance dict(s) as lists.

        A single tensor is wrapped in a one-element list, as is its
        compliance dict unless that is None. A list of tensors is stored
        as-is and must come with a list of compliance dicts or None.
        """
        if not isinstance(resultTensor, list):
            self.resultTensorList = [resultTensor]
            self.complianceDictList = (
                None if complianceDict is None else [complianceDict]
            )
            return

        assert complianceDict is None or isinstance(complianceDict, list)
        self.resultTensorList = resultTensor
        self.complianceDictList = complianceDict

    def getComplianceInfo(self):
        """Return compliance info keyed by result tensor name, or None.

        None is returned when there is no compliance list or every entry in
        it is None; otherwise {"version": "0.1", "tensors": {name: dict}}.
        """
        if self.complianceDictList is None:
            return None

        tens_dict = {
            tens.name: comp
            for tens, comp in zip(self.resultTensorList, self.complianceDictList)
            if comp is not None
        }
        if not tens_dict:
            return None
        # Have some compliance data, so return the info
        return {"version": "0.1", "tensors": tens_dict}
Eric Kunzee5e26762020-10-13 16:11:07 -0700441
def build_unary(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add the unary operator ``op["op"]`` on ``inputs[0]`` to the graph.

    NEGATE gets a NegateAttribute built from ``qinfo``. For the
    WrongOutputType error test, ``qinfo`` is regenerated for the new output
    type when that type is not INT8/UINT8.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None when
        TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    assert len(inputs) == 1
    a = inputs[0]
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    assert not isinstance(op, int)

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongOutputType and out_tensor.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, a.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    error_if_args = {
        "op": op,
        "input_dtype": a.dtype,
        "output_dtype": out_tensor.dtype,
        "qinfo": qinfo,
        "result_tensors": [out_tensor],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": p_count + c_count,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])
    else:
        attr = None

    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700494
def build_binary_broadcast(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Add the broadcasting binary operator ``op["op"]`` on ``inputs`` (a, b).

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None when
        TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    out_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # Invalidate Input/Output list for error if checks.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    error_if_args = {
        "op": op,
        "input1": lhs,
        "input2": rhs,
        "input_dtype": lhs.dtype,
        "output_dtype": out_tensor.dtype,
        "result_tensors": [out_tensor],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": p_count + c_count,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700536
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add the non-broadcasting binary operator ``op["op"]`` on ``a`` and ``b``.

    This builder takes the two tensors directly and does no ERROR_IF
    handling (``validator_fcns`` and ``error_name`` are unused). It returns
    the result tensor itself rather than a BuildInfo.
    """
    out_tensor = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op["op"], operand_names, [out_tensor.name])
    return out_tensor
541
def build_arithmetic_right_shift(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add the broadcasting binary operator ``op["op"]`` with an ArithmeticRightShiftAttribute.

    ``args_dict["round"]`` is passed to ArithmeticRightShiftAttribute.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None when
        TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    assert len(inputs) == 2
    a, b = inputs
    # Not called `round`, so the builtin round() is not shadowed
    round_attr = args_dict["round"]
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(round_attr)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800587
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a multiply operator on ``inputs`` (a, b, s).

    ``a`` and ``b`` broadcast to form the output. ``s`` is the shift value
    tensor and is passed as a third operator input. Non-float inputs produce
    an INT32 result.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None when
        TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    # Note that mul is binary operator but it has a shift value tensor
    assert len(inputs) == 3
    a, b, shift = inputs

    out_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Special for multiply: Force the result to INT32 for INT types
    if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
        out_tensor.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        # Swap in one of INT8/INT16/INT48 for the WrongOutputType test
        out_tensor.setDtype(self.rng.choice([DType.INT8, DType.INT16, DType.INT48]))

    # Invalidate Input/Output list for error if checks.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name, shift.name], [out_tensor.name]
    )

    error_if_args = {
        "op": op,
        "input1": a,
        "input2": b,
        "input_dtype": a.dtype,
        "output_dtype": out_tensor.dtype,
        "result_tensors": [out_tensor],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": p_count + c_count,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700640
def build_table(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add the operator ``op["op"]`` on ``inputs[0]`` with a TableAttribute.

    The table values come from ``args_dict["table"]``.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None when
        TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    assert len(inputs) == 1
    a = inputs[0]
    lookup_table = args_dict["table"]
    out_tensor = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TableAttribute(lookup_table)

    # Invalidate Input/Output list for error if checks.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    error_if_args = {
        "op": op,
        "input_shape": a.shape,
        "input_dtype": a.dtype,
        "output_dtype": out_tensor.dtype,
        "result_tensors": [out_tensor],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": p_count + c_count,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700683
def build_select(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a SELECT operator whose operands are ``inputs == [cond, a, b]``.

    The output shape/dtype come from OutputShaper.selectOp. Compliance metadata
    is derived from the dtype of the second input.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        the ERROR_IF validation rejects this test.
    """
    assert len(inputs) == 3
    cond, lhs, rhs = inputs

    out_tensor = OutputShaper.selectOp(
        self.ser, self.rng, cond, lhs, rhs, error_name
    )

    # For ERROR_IF tests the helper may invalidate the operand name lists.
    n_placeholders, n_consts = op["operands"]
    operand_total = n_placeholders + n_consts
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, lhs.name, rhs.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=lhs,
        input3=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance_info = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance_info)
Eric Kunzee5e26762020-10-13 16:11:07 -0700731
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a binary comparison operator over ``inputs == [a, b]``.

    The output tensor is created by OutputShaper.binaryComparisonOp.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        the ERROR_IF validation rejects this test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # For ERROR_IF tests the helper may invalidate the operand name lists.
    n_placeholders, n_consts = op["operands"]
    operand_total = n_placeholders + n_consts
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance_info = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance_info)
Eric Kunzee5e26762020-10-13 16:11:07 -0700779
def build_argmax(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add an ARGMAX operator reducing the single input along ``args_dict["axis"]``.

    ``validator_fcns`` and ``error_name`` now default to None, consistent with
    the other build_* methods. Positional callers are unaffected.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        the ERROR_IF validation rejects this test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700823
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D pooling operator for the single tensor in ``inputs``.

    ``args_dict`` supplies ``kernel``, ``stride`` and ``pad``, and optionally
    ``acc_type``. When ``acc_type`` is absent, DType.UNKNOWN is used, since
    max_pool has no accumulator type. A None ``qinfo`` is serialized as zero
    points ``[0, 0]``.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        the ERROR_IF validation rejects this test.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    out_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # Recompute zero points so they match the (deliberately wrong) dtypes.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    # For ERROR_IF tests the helper may invalidate the operand name lists.
    n_placeholders, n_consts = op["operands"]
    operand_total = n_placeholders + n_consts
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        accum_dtype=accum_dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    zero_points = [0, 0] if qinfo is None else qinfo

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )

    self.ser.addOperator(op["op"], in_names, out_names, pool_attr)

    compliance_info = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance_info)
James Ward8b390432022-08-12 20:48:56 +0100898
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a CONV2D operator to the serializer.

    Args:
        op: Operator description. ``op["op"]`` is serialized and
            ``op["operands"]`` is summed to get the operand count.
        inputs: ``[ifm, filter, bias]`` tensors.
        args_dict: Must provide ``acc_type``, ``stride``, ``pad`` (4 values)
            and ``dilation``.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: Two-element zero-point list passed to ConvAttribute, or None.
            None is treated as ``[0, 0]`` instead of raising TypeError.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        the ERROR_IF validation rejects this test.
    """
    assert len(inputs) == 3
    # "weight" avoids shadowing the builtin filter().
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; indexing it below would raise
    # TypeError, so fall back to zero zero-points.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700980
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a CONV3D operator to the serializer.

    Args:
        op: Operator description. ``op["op"]`` is serialized and
            ``op["operands"]`` is summed to get the operand count.
        inputs: ``[ifm, filter, bias]`` tensors.
        args_dict: Must provide ``acc_type``, ``stride``, ``pad`` (6 values)
            and ``dilation``.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: Two-element zero-point list passed to ConvAttribute, or None.
            None is treated as ``[0, 0]`` instead of raising TypeError.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        the ERROR_IF validation rejects this test.
    """
    assert len(inputs) == 3
    # "weight" avoids shadowing the builtin filter().
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; indexing it below would raise
    # TypeError, so fall back to zero zero-points.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001062
def build_transpose_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a TRANSPOSE_CONV2D operator to the serializer.

    Args:
        op: Operator description. ``op["op"]`` is serialized and
            ``op["operands"]`` is summed to get the operand count.
        inputs: ``[ifm, filter, bias]`` tensors.
        args_dict: Must provide ``acc_type``, ``stride``, ``pad`` (4 values,
            used as the out_pad attribute) and ``out_shape``.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: Two-element zero-point list passed to TransposeConvAttribute,
            or None. None is treated as ``[0, 0]`` instead of raising TypeError.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        the ERROR_IF validation rejects this test.
    """
    assert len(inputs) == 3
    # "weight" avoids shadowing the builtin filter().
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; indexing it below would raise
    # TypeError, so fall back to zero zero-points.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001137
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a DEPTHWISE_CONV2D operator to the serializer.

    Args:
        op: Operator description. ``op["op"]`` is serialized and
            ``op["operands"]`` is summed to get the operand count.
        inputs: ``[ifm, filter, bias]`` tensors.
        args_dict: Must provide ``acc_type``, ``stride``, ``pad`` and
            ``dilation``.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: Two-element zero-point list passed to ConvAttribute, or None.
            None is treated as ``[0, 0]`` instead of raising TypeError.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        the ERROR_IF validation rejects this test.
    """
    assert len(inputs) == 3
    # "weight" avoids shadowing the builtin filter().
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; indexing it below would raise
    # TypeError, so fall back to zero zero-points.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001218
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a FULLY_CONNECTED operator to the serializer.

    Args:
        op: Operator description. ``op["op"]`` is serialized and
            ``op["operands"]`` is unpacked into two counts that are summed.
        inputs: ``[ifm, filter, bias]`` tensors.
        args_dict: Must provide ``acc_type``.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: Two-element zero-point list passed to FullyConnectedAttribute,
            or None. None is treated as ``[0, 0]`` instead of raising TypeError.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        the ERROR_IF validation rejects this test.
    """
    assert len(inputs) == 3
    # "weight" avoids shadowing the builtin filter().
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; indexing it below would raise
    # TypeError, so fall back to zero zero-points.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001274
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a MATMUL operator over ``inputs == [a, b]``.

    Args:
        op: Operator description. ``op["op"]`` is serialized and
            ``op["operands"]`` is unpacked into two counts that are summed.
        inputs: ``[a, b]`` tensors.
        args_dict: Must provide ``acc_type``.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: Two-element zero-point list passed to MatMulAttribute, or None.
            None is treated as ``[0, 0]`` instead of raising TypeError.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        the ERROR_IF validation rejects this test.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; indexing it below would raise
    # TypeError, so fall back to zero zero-points.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001324
def build_reduce(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a reduction operator over ``args_dict["axis"]`` of the single input.

    ``validator_fcns`` now defaults to None, consistent with the other build_*
    methods. Positional callers are unaffected.

    For non-ERROR_IF REDUCE_PRODUCT tests this writes the length of the reduced
    axis into ``args_dict["n"]`` (mutating the caller's dict) for compliance.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        the ERROR_IF validation rejects this test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance
        args_dict["n"] = a.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001373
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a CLAMP operator with randomly drawn bounds.

    Two random values of the input dtype are drawn. Normally the smaller one
    becomes the min bound and the larger one the max bound. For the
    ``ErrorIf.MaxSmallerMin`` case the pair is redrawn until the values
    differ, and the bounds are then deliberately swapped.

    Args:
        op: Operator description (``"op"`` opcode, ``"operands"`` pair).
        inputs: Single-element list holding the input tensor.
        args_dict: Test arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` pairing the result tensor with its
        compliance meta data, or None when the ERROR_IF validation fails.
    """
    assert len(inputs) == 1
    tensor_in = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, tensor_in, error_name)

    def draw_pair():
        return [
            self.getRandNumberDType(tensor_in.dtype),
            self.getRandNumberDType(tensor_in.dtype),
        ]

    bounds = draw_pair()
    if error_name != ErrorIf.MaxSmallerMin:
        min_val, max_val = min(bounds), max(bounds)
    else:
        # Equal values would not trigger the error, so redraw until they differ.
        while bounds[0] == bounds[1]:
            bounds = draw_pair()
        min_val, max_val = max(bounds), min(bounds)

    # For ERROR_IF tests the input/output name lists may be invalidated.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor_in.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=tensor_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=tensor_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not checks_pass:
        return None

    # ClampAttribute takes (min_int, max_int, min_fp, max_fp); only the pair
    # matching the input dtype family carries the drawn bounds.
    clamp_attr = ts.TosaSerializerAttribute()
    if tensor_in.dtype in (DType.BF16, DType.FP16, DType.FP32):
        if tensor_in.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        clamp_values = (0, 0, min_val, max_val)
    elif tensor_in.dtype in (DType.INT8, DType.INT16):
        clamp_values = (min_val, max_val, 0, 0)
    else:
        # to avoid internal error for incorrect input types
        clamp_values = (0, 0, 0, 0)
    clamp_attr.ClampAttribute(self.ser.builder, *clamp_values)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, tensor_in.dtype, args_dict, out_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001442
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a LEAKY_RELU operator with a random FP32 attribute value.

    Unlike most builders in this class, this one takes the input tensor
    directly, performs no ERROR_IF validation (``validator_fcns`` is
    ignored) and returns the result tensor rather than a ``BuildInfo``.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    relu_attr = ts.TosaSerializerAttribute()

    relu_attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32))

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], relu_attr)
    return out_tensor
1451
# NOTE: Needs an additional type/input -- only the single tensor ``a`` is
# wired into the operator below.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a PRELU operator with one input and no attribute.

    No ERROR_IF validation is performed (``validator_fcns`` is ignored), and
    the result tensor is returned directly rather than wrapped in a
    ``BuildInfo``.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1458
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize an attribute-less activation operator on one input tensor.

    The output tensor is created by ``OutputShaper.unaryOp``.

    Args:
        op: Operator description (``"op"`` opcode, ``"operands"`` pair).
        inputs: Single-element list holding the input tensor.
        args_dict: Test arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` pairing the result tensor with its
        compliance meta data, or None when the ERROR_IF validation fails.
    """
    assert len(inputs) == 1
    tensor_in = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, tensor_in, error_name)

    # For ERROR_IF tests the input/output name lists may be invalidated.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor_in.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=tensor_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=tensor_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, tensor_in.dtype, args_dict, out_tensor, error_name
        ),
    )
Won Jeon78155c62023-06-10 00:20:04 +00001499
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a CONCAT or CONCAT_SHAPE operator over all ``inputs``.

    CONCAT reads its axis from ``args_dict["axis"]`` and attaches it as an
    ``AxisAttribute``. CONCAT_SHAPE always uses axis 0 and is serialized
    without an attribute.

    Args:
        op: Operator description (``"op"`` opcode, ``"operands"`` pair).
        inputs: Tensors to concatenate. The first tensor supplies the input
            shape/dtype reported to the ERROR_IF validators and the dtype
            used for compliance.
        args_dict: Test arguments (``"axis"`` for CONCAT).
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` pairing the result tensor with its
        compliance meta data, or None when the ERROR_IF validation fails.
    """
    axis = 0 if op["op"] == Op.CONCAT_SHAPE else args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # isinstance is the idiomatic type check (was ``type(axis) == int``).
        assert isinstance(axis, int)

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if op["op"] == Op.CONCAT:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    else:
        assert op["op"] == Op.CONCAT_SHAPE
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001558
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a PAD operator whose padding is supplied as ``inputs[1]``.

    Args:
        op: Operator description (``"op"`` opcode, ``"operands"`` pair).
        inputs: Two tensors: the tensor to pad and the padding input.
        args_dict: Test arguments. ``"pad"`` drives output-shape inference
            and the ERROR_IF checks. ``"pad_const_int"`` and ``"pad_const_fp"``
            go into the ``PadAttribute``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Forwarded to the ERROR_IF validators.

    Returns:
        ``TosaTestGen.BuildInfo`` pairing the result tensor with its
        compliance meta data, or None when the ERROR_IF validation fails.
    """
    assert len(inputs) == 2
    tensor_in = inputs[0]
    padding_tensor = inputs[1]
    pad_amounts = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out_tensor = OutputShaper.padOp(
        self.ser, self.rng, tensor_in, pad_amounts, error_name
    )

    # The attribute gets an empty padding list so that inputs[1] is the one
    # actually used for padding.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, [], const_int, const_fp)

    # For ERROR_IF tests the input/output name lists may be invalidated.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor_in.name, padding_tensor.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=tensor_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=tensor_in.dtype,
        output_dtype=out_tensor.dtype,
        pad=pad_amounts,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
        input1=tensor_in,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, tensor_in.dtype, args_dict, out_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001616
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a DIM operator for ``args_dict["axis"]`` of one input tensor.

    Args:
        op: Operator description (``"op"`` opcode, ``"operands"`` pair).
        inputs: Single-element list holding the input tensor.
        args_dict: Test arguments (``"axis"`` is read).
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and no compliance
        meta data (None), or None when the ERROR_IF validation fails.
    """
    assert len(inputs) == 1
    tensor_in = inputs[0]
    dim_axis = args_dict["axis"]
    out_tensor = OutputShaper.dimOp(self.ser, self.rng, tensor_in, dim_axis, error_name)

    # For ERROR_IF tests the input/output name lists may be invalidated.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor_in.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=dim_axis,
        input_shape=tensor_in.shape,
        input_dtype=tensor_in.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not checks_pass:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(dim_axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    # No compliance meta data is generated for this operator.
    return TosaTestGen.BuildInfo(out_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001662
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a RESHAPE operator whose target shape is ``inputs[1]``.

    The output shape is inferred from ``args_dict["new_shape"]``, and the
    operator is serialized without an attribute.

    Args:
        op: Operator description (``"op"`` opcode, ``"operands"`` pair).
        inputs: Two tensors: the data tensor and the shape input.
        args_dict: Test arguments (``"new_shape"`` is read).
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` pairing the result tensor with its
        compliance meta data, or None when the ERROR_IF validation fails.
    """
    assert len(inputs) == 2
    tensor_in = inputs[0]
    shape_tensor = inputs[1]
    target_shape = args_dict["new_shape"]
    out_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, tensor_in, target_shape, error_name
    )

    # For ERROR_IF tests the input/output name lists may be invalidated.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor_in.name, shape_tensor.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=tensor_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=tensor_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, tensor_in.dtype, args_dict, out_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001706
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a REVERSE operator along ``args_dict["axis"]``.

    The output tensor is created by ``OutputShaper.unaryOp``.

    Args:
        op: Operator description (``"op"`` opcode, ``"operands"`` pair).
        inputs: Single-element list holding the input tensor.
        args_dict: Test arguments (``"axis"`` is read).
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and no compliance
        meta data (None), or None when the ERROR_IF validation fails.
    """
    assert len(inputs) == 1
    tensor_in = inputs[0]
    rev_axis = args_dict["axis"]
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, tensor_in, error_name)

    # For ERROR_IF tests the input/output name lists may be invalidated.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor_in.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=rev_axis,
        input_shape=tensor_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=tensor_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not checks_pass:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(rev_axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    # No compliance meta data is generated for this operator.
    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001746
def build_transpose(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a TRANSPOSE operator using ``args_dict["perms"]``.

    Args:
        op: Operator description (``"op"`` opcode, ``"operands"`` pair).
        inputs: Single-element list holding the input tensor.
        args_dict: Test arguments (``"perms"`` is read).
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` pairing the result tensor with its
        compliance meta data, or None when the ERROR_IF validation fails.
    """
    assert len(inputs) == 1
    tensor_in = inputs[0]
    permutation = args_dict["perms"]

    out_tensor = OutputShaper.transposeOp(
        self.ser, self.rng, tensor_in, permutation, error_name
    )

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(permutation)

    # For ERROR_IF tests the input/output name lists may be invalidated.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor_in.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=tensor_in.shape,
        output_shape=out_tensor.shape,
        perms=permutation,
        input_dtype=tensor_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
        input1=tensor_in,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, perm_attr)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, tensor_in.dtype, args_dict, out_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001795
def build_slice(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a SLICE operator whose start and size are tensor inputs.

    The constant ``args_dict["start"]`` / ``args_dict["size"]`` values drive
    output-shape inference and the ERROR_IF checks. The operator is
    serialized without an attribute.

    Args:
        op: Operator description (``"op"`` opcode, ``"operands"`` pair).
        inputs: Three tensors: data, start input and size input.
        args_dict: Test arguments (``"start"`` and ``"size"`` are read).
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` pairing the result tensor with its
        compliance meta data, or None when the ERROR_IF validation fails.
    """
    assert len(inputs) == 3
    tensor_in, start_tensor, size_tensor = inputs
    start_values = args_dict["start"]
    size_values = args_dict["size"]

    out_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, tensor_in, start_values, size_values, error_name
    )

    # For ERROR_IF tests the input/output name lists may be invalidated.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [tensor_in.name, start_tensor.name, size_tensor.name],
        [out_tensor.name],
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=tensor_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=tensor_in.dtype,
        output_dtype=out_tensor.dtype,
        start=start_values,
        size=size_values,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
        input1=tensor_in,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, tensor_in.dtype, args_dict, out_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001843
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a TILE operator whose multiples are supplied as ``inputs[1]``.

    The output shape is inferred from ``args_dict["multiples"]``, and the
    operator is serialized without an attribute.

    Args:
        op: Operator description (``"op"`` opcode, ``"operands"`` pair).
        inputs: Two tensors: the data tensor and the multiples input.
        args_dict: Test arguments (``"multiples"`` is read).
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` pairing the result tensor with its
        compliance meta data, or None when the ERROR_IF validation fails.
    """
    assert len(inputs) == 2
    tensor_in = inputs[0]
    multiples_tensor = inputs[1]
    tile_multiples = args_dict["multiples"]
    out_tensor = OutputShaper.tileOp(
        self.ser, self.rng, tensor_in, tile_multiples, error_name
    )

    # For ERROR_IF tests the input/output name lists may be invalidated.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor_in.name, multiples_tensor.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=tensor_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=tensor_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
        input1=tensor_in,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, tensor_in.dtype, args_dict, out_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001888
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a GATHER operator from a values tensor and an indices tensor.

    Args:
        op: Operator description (``"op"`` opcode, ``"operands"`` pair).
        inputs: Two tensors: values and indices. The values tensor supplies
            the input shape/dtype for ERROR_IF checks and compliance.
        args_dict: Test arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` pairing the result tensor with its
        compliance meta data, or None when the ERROR_IF validation fails.
    """
    assert len(inputs) == 2
    values_tensor, indices_tensor = inputs

    out_tensor = OutputShaper.gatherOp(
        self.ser, self.rng, values_tensor, indices_tensor, error_name
    )

    # For ERROR_IF tests the input/output name lists may be invalidated.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values_tensor.name, indices_tensor.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, values_tensor.dtype, args_dict, out_tensor, error_name
        ),
    )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001931
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a SCATTER operator.

    Args:
        op: Operator description (``"op"`` opcode, ``"operands"`` pair).
        inputs: Three tensors, in order: values_in, indices and input. The
            values_in tensor supplies the input shape/dtype for ERROR_IF
            checks and compliance.
        args_dict: Test arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` pairing the result tensor with its
        compliance meta data, or None when the ERROR_IF validation fails.
    """
    assert len(inputs) == 3
    # ``input_tensor`` (was ``input``) avoids shadowing the ``input`` builtin.
    values_in, indices, input_tensor = inputs
    result_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input_tensor, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [values_in.name, indices.name, input_tensor.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001973
def build_resize(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns,
    error_name=None,
    qinfo=None,
):
    """Build a RESIZE operator test.

    ``inputs`` holds four tensors: the data input followed by the scale,
    offset and border tensors, whose names are wired into the operator's
    input list. ``args_dict`` supplies ``mode``, ``scale``, ``offset``,
    ``border`` and ``output_dtype``. The scale/offset/border values from
    ``args_dict`` are used for output-shape calculation and ERROR_IF
    validation, while the serialized ResizeAttribute gets empty lists for
    them (presumably because the values travel in the extra input tensors —
    NOTE(review): confirm).

    Returns a ``TosaTestGen.BuildInfo`` or ``None`` when the ERROR_IF
    validators reject the configuration.
    """
    assert len(inputs) == 4
    data_tens = inputs[0]
    scale_input = inputs[1]
    offset_input = inputs[2]
    border_input = inputs[3]

    mode = args_dict["mode"]
    scale = args_dict["scale"]
    offset = args_dict["offset"]
    border = args_dict["border"]
    output_dtype = args_dict["output_dtype"]

    out_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        data_tens,
        mode,
        scale,
        offset,
        border,
        data_tens.dtype,
        output_dtype,
        error_name,
    )

    # ERROR_IF tests may corrupt the operand name lists
    primary_count, const_count = op["operands"]
    operand_names = [
        tens.name for tens in (data_tens, scale_input, offset_input, border_input)
    ]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, operand_names, [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=data_tens.dtype,
        output_dtype=output_dtype,
        input_shape=data_tens.shape,
        output_shape=out_tensor.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[out_tensor],
        num_operands=primary_count + const_count,
    )
    if not checks_pass:
        return None

    attr = ts.TosaSerializerAttribute()
    # Only the mode is carried in the attribute; scale/offset/border are empty
    attr.ResizeAttribute([], [], [], mode)
    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = self.tensorComplianceMetaData(
        op, data_tens.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002052
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an operator with two inputs and two outputs (IDENTITYN style).

    Each output tensor is shaped by ``OutputShaper.unaryOp`` from its
    corresponding input. Only the first output tensor is returned.
    """
    # NOTE(review): unlike the other builders, this passes ``op`` directly to
    # addOperator (not ``op["op"]``) and returns a bare tensor rather than a
    # BuildInfo - confirm whether this legacy builder is still reachable.
    outputs = [
        OutputShaper.unaryOp(self.ser, self.rng, tens, error_name)
        for tens in (val, val2)
    ]
    self.ser.addOperator(op, [val.name, val2.name], [out.name for out in outputs])
    return outputs[0]
2060
def build_const(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONST test by exposing the single input tensor as a graph output.

    Returns a ``TosaTestGen.BuildInfo`` for that tensor and its compliance
    meta data.
    """
    assert len(inputs) == 1
    const_tens = inputs[0]
    self.ser.addOutputTensor(const_tens)

    return TosaTestGen.BuildInfo(
        const_tens,
        self.tensorComplianceMetaData(
            op, const_tens.dtype, args_dict, const_tens, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002073
# Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CAST operator test.

    Converts the single tensor in ``inputs`` to ``args_dict["out_type"]``.
    Returns a ``TosaTestGen.BuildInfo`` (result tensor plus compliance meta
    data), or ``None`` when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    src_tens = inputs[0]
    target_dtype = args_dict["out_type"]

    cast_result = OutputShaper.typeConversionOp(
        self.ser, self.rng, src_tens, target_dtype, error_name
    )

    # ERROR_IF tests may corrupt the operand name lists
    n_primary, n_const = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src_tens.name], [cast_result.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src_tens.shape,
        output_shape=cast_result.shape,
        input_dtype=src_tens.dtype,
        output_dtype=cast_result.dtype,
        result_tensors=[cast_result],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_primary + n_const,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src_tens.dtype, args_dict, cast_result, error_name
    )
    return TosaTestGen.BuildInfo(cast_result, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002118
def build_rescale(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESCALE operator test.

    ``inputs`` holds three tensors: the data input, the multiplier tensor and
    the shift tensor. ``args_dict`` supplies ``output_dtype``, ``scale``
    (the scale32 flag), ``double_round``, ``per_channel`` and ``shift`` (the
    shift values used to bound the input data).

    Random input/output zero points are picked according to the input/output
    dtypes, or forced non-zero for the zero-point ERROR_IF tests. For
    non-ERROR_IF scale32 tests, the input data file is rewritten when needed
    so that each ``value - input_zp`` lies in
    ``[-(1 << (shift - 1)), (1 << (shift - 1)) - 1]`` for its channel.

    Returns a ``TosaTestGen.BuildInfo`` or ``None`` when the ERROR_IF
    validators reject the configuration.
    """
    assert len(inputs) == 3
    val = inputs[0]
    multiplier_val = inputs[1]
    shift_val = inputs[2]
    out_dtype = args_dict["output_dtype"]
    scale32 = args_dict["scale"]
    double_round = args_dict["double_round"]
    per_channel = args_dict["per_channel"]
    shift_arr = args_dict["shift"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # Number of channels the shift values apply to
    if per_channel:
        nc = val.shape[-1]
    else:
        nc = 1

    input_unsigned = False
    output_unsigned = False

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        # Force a non-zero (invalid) input zero point
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        # Force a non-zero (invalid) output zero point
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        output_unsigned = True
    else:
        output_zp = 0

    # Per-channel bounds allowed for (value - input_zp) by each shift
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    for i in range(nc):
        min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
        max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(values_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check we can safely convert to the expected dtype. Compare the real
        # extremes - ndarray.all() returns a single bool, so comparing it to
        # the dtype limits checked nothing.
        dtype_limits = np.iinfo(values.dtype)
        assert val_adj.size == 0 or (
            val_adj.min() >= dtype_limits.min and val_adj.max() <= dtype_limits.max
        )

        # Force casting to output datatype
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(values_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name, multiplier_val.name, shift_val.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    # Multiplier/shift lists are left empty in the attribute; the operator
    # receives multiplier_val/shift_val as inputs instead.
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        [],
        [],
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002283
def _get_condition_tensor(self, op, cond, error_name):
    """Add and return a constant COND_IF condition tensor holding ``cond``.

    Normally this is a rank-0 BOOL tensor. The CondIfCondNotMatchingBool
    ERROR_IF test gets a wrong dtype instead, and CondIfCondShapeNotSizeOne
    gets a shape of ``[2]`` or ``[1, 2]``.
    """
    if error_name == ErrorIf.CondIfCondNotMatchingBool:
        cond_type = gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
    else:
        cond_type = DType.BOOL

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2300
def build_cond_if_const(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose THEN/ELSE blocks each output a constant.

    The two supplied tensors are otherwise ignored: only the shape of the
    first (THEN) tensor is used. The branch bodies are INT32 constants filled
    with random values in [0, 256), and ``args_dict["condition"]`` is stored
    in the condition tensor.

    For the output-list mismatch ERROR_IF tests, the offending branch
    produces a constant of a perturbed shape.

    Returns a ``TosaTestGen.BuildInfo`` or ``None`` when the ERROR_IF
    validators reject the configuration.
    """
    assert len(inputs) == 2
    then_tens, _else_tens = inputs

    cond = args_dict["condition"]
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape
    dtype = DType.INT32

    if error_name in (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ):
        # Perturb every dimension; small dimensions are only grown so that
        # they stay positive.
        incorrect_shape = [
            dim
            + (
                self.rng.choice([-3, -2, 2, 3])
                if dim > 3
                else self.rng.choice([1, 2, 4])
            )
            for dim in then_tens.shape
        ]
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The operator's result matches the (correct) branch output shape/type
    result_tensor = self.ser.addOutput(out_shape, dtype)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)

    # Build each branch block with a single constant as its output
    for block_name, block_arr, mismatch_error in (
        (then_block, then_arr, ErrorIf.CondIfOutputListThenGraphMismatch),
        (else_block, else_arr, ErrorIf.CondIfOutputListElseGraphMismatch),
    ):
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            block_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
        else:
            block_tens = self.ser.addConst(out_shape, dtype, block_arr)
        self.ser.addOutputTensor(block_tens)

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not checks_pass:
        return None

    compliance = self.tensorComplianceMetaData(
        op, dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002385
def build_cond_if_binary(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose THEN/ELSE blocks each apply a binary op.

    Both blocks take ``a`` and ``b`` as inputs. For FP32/BF16/FP16/INT32 the
    THEN block adds and the ELSE block subtracts; for INT8/INT16 they use
    logical right shift and logical left shift respectively. Compliance is
    checked against the op of the branch selected by
    ``args_dict["condition"]``.

    For the block input/output list mismatch ERROR_IF tests, the relevant
    block uses a copy of ``a`` with a perturbed shape.

    Returns a ``TosaTestGen.BuildInfo`` or ``None`` when the ERROR_IF
    validators reject the configuration.
    """
    assert len(inputs) == 2
    a, b = inputs

    cond = args_dict["condition"]
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tensor = self.ser.addOutput(a.shape, a.dtype)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in (
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ):
        # Shift every dimension of ``a`` by a random non-zero step
        bad_shape = [dim + self.rng.choice([-3, -2, 2, 3]) for dim in a.shape]
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = bad_shape

    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_key, else_key = "add", "sub"
    elif a.dtype in (DType.INT8, DType.INT16):
        then_key, else_key = "logical_right_shift", "logical_left_shift"
    else:
        assert False, f"No tests for DType: {a.dtype}"
    then_op = self.TOSA_OP_LIST[then_key]
    else_op = self.TOSA_OP_LIST[else_key]

    # Compliance checks the element-wise op of the branch that is taken
    compliance_op = then_op if cond else else_op

    for block, block_op, input_mismatch, output_mismatch in (
        (
            then_block,
            then_op,
            ErrorIf.CondIfInputListThenGraphMismatch,
            ErrorIf.CondIfOutputListThenGraphMismatch,
        ),
        (
            else_block,
            else_op,
            ErrorIf.CondIfInputListElseGraphMismatch,
            ErrorIf.CondIfOutputListElseGraphMismatch,
        ),
    ):
        self.ser.addBasicBlock(block)
        if error_name == input_mismatch:
            self.ser.addInputTensor(incorrect_block_input)
        else:
            self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        if error_name == output_mismatch:
            block_out_shape = incorrect_block_input.shape
        else:
            block_out_shape = a.shape
        block_out = self.ser.addOutput(block_out_shape, a.dtype)
        self.ser.addOperator(block_op["op"], [a.name, b.name], [block_out.name])

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not checks_pass:
        return None

    compliance = self.tensorComplianceMetaData(
        compliance_op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
2489
def build_while_loop(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a WHILE_LOOP test that repeatedly adds ``a`` into an accumulator.

    The loop carries three values:
    - an INT32 counter initialised to ``args_dict["iterations"]``;
    - the input ``a``;
    - an accumulator placeholder initialised to zeros.

    The COND block computes ``iter > 0``. The BODY block computes
    ``acc + a`` and ``iter - 1``. The final accumulator is added as a graph
    output.

    Several ERROR_IF tests perturb tensor shapes (or the condition output
    dtype/shape) so that the loop operator and its blocks disagree.

    Returns a ``TosaTestGen.BuildInfo`` or ``None`` when the ERROR_IF
    validators reject the configuration.
    """
    assert len(inputs) == 1
    a = inputs[0]
    iter_val = args_dict["iterations"]

    def perturbed_copy(tens):
        # Deep copy ``tens`` and move each dimension by a random non-zero step
        bad = deepcopy(tens)
        for idx, dim in enumerate(bad.shape):
            bad.shape[idx] = dim + self.rng.choice([-3, -2, 2, 3])
        return bad

    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, starting from zeros
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out_shape = perturbed_copy(acc).shape
    else:
        acc_out_shape = acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in (
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ):
        incorrect_iter = perturbed_copy(iter_tens)
        if not incorrect_iter.shape:
            # Rank-0 counter: add a bogus dimension so its shape differs
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
        incorrect_acc = perturbed_copy(acc)

    # COND block (inputs: iter, a, acc; output: cond_tens)
    self.ser.addBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        cond_inputs = (iter_tens, a, acc)
    for tens in cond_inputs:
        self.ser.addInputTensor(tens)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name != ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [3]
    else:
        cond_shape = [1, 2]
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (inputs: iter, a, acc; outputs: iter, a, acc)
    # Local intermediate tensors must be declared here for the outputs
    self.ser.addBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        body_inputs = (iter_tens, a, acc)
    for tens in body_inputs:
        self.ser.addInputTensor(tens)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_ref, acc_ref = incorrect_iter, incorrect_acc
    else:
        iter_ref, acc_ref = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(iter_ref.shape, iter_ref.dtype)
    acc_body_out = self.ser.addIntermediate(acc_ref.shape, acc_ref.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for tens in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tens)

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    )
    if not checks_pass:
        return None

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, acc_out, error_name
    )
    return TosaTestGen.BuildInfo(acc_out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002620
def build_fft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an FFT2D operator test from a pair of input tensors.

    Args:
        op: the TOSA_OP_LIST entry for the operator.
        inputs: list holding exactly two input tensors.
        args_dict: test arguments; must contain "inverse", and is also
            forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validators to check the test against.
        error_name: the ERROR_IF condition being tested, or None.
        qinfo: unused by this builder.

    Returns:
        A TosaTestGen.BuildInfo of the result tensors and their compliance
        meta data, or None if the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    first, second = inputs
    is_inverse = args_dict["inverse"]

    results = OutputShaper.fft2dOp(self.ser, self.rng, first, second, error_name)

    placeholder_count, const_count = op["operands"]
    shapes_out = [res.shape for res in results]
    dtypes_out = [res.dtype for res in results]

    # Invalidate the input/output lists for ERROR_IF checks (when required).
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [first.name, second.name],
        [res.name for res in results],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=is_inverse,
        input1=first,
        input2=second,
        input_shape=first.shape,
        input_dtype=first.dtype,
        output_shape=shapes_out,
        output_dtype=dtypes_out,
        result_tensors=results,
        input_list=names_in,
        output_list=names_out,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    attr = ts.TosaSerializerAttribute()
    attr.FFTAttribute(is_inverse, local_bound)
    self.ser.addOperator(op["op"], names_in, names_out, attr)

    # One compliance entry per result tensor, all keyed off the first input's dtype.
    compliance = [
        self.tensorComplianceMetaData(op, first.dtype, args_dict, res, error_name)
        for res in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002684
def build_rfft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an RFFT2D operator test from a single input tensor.

    Args:
        op: the TOSA_OP_LIST entry for the operator.
        inputs: list holding exactly one input tensor.
        args_dict: test arguments; forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validators to check the test against.
        error_name: the ERROR_IF condition being tested, or None.
        qinfo: unused by this builder.

    Returns:
        A TosaTestGen.BuildInfo of the result tensors and their compliance
        meta data, or None if the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    (val,) = inputs

    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholder_count, const_count = op["operands"]
    shapes_out = [res.shape for res in results]
    dtypes_out = [res.dtype for res in results]

    # Invalidate the input/output lists for ERROR_IF checks (when required).
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [res.name for res in results]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=shapes_out,
        output_dtype=dtypes_out,
        result_tensors=results,
        input_list=names_in,
        output_list=names_out,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    attr = ts.TosaSerializerAttribute()
    attr.RFFTAttribute(local_bound)
    self.ser.addOperator(op["op"], names_in, names_out, attr)

    compliance = [
        self.tensorComplianceMetaData(op, val.dtype, args_dict, res, error_name)
        for res in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002741
def build_shape_op(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a binary shape-operator test (two inputs, one output).

    Args:
        op: the TOSA_OP_LIST entry for the operator.
        inputs: list holding exactly two input tensors.
        args_dict: test arguments; forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validators to check the test against.
        error_name: the ERROR_IF condition being tested, or None.
        qinfo: unused by this builder.

    Returns:
        A TosaTestGen.BuildInfo of the result tensor and its compliance
        meta data, or None if the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.addShapeOp(self.ser, self.rng, lhs, rhs, error_name)

    placeholder_count, const_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=names_in,
        output_list=names_out,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], names_in, names_out)
    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
2787
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out which shapes, ranks and dtypes to generate tests for.

    Args:
        op: the TOSA_OP_LIST entry; its "rank" (min, max) and "types" are used.
        shapeFilter: requested shapes; a falsy value means "no restriction".
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None.
        testType: "positive" or "negative".
        validator: ERROR_IF validator whose "param_reqs" override the filters
            for negative tests.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or
        None for negative tests without a validator (or an unknown testType).
    """
    # Default testing rank range, 1-4 inclusive, keeps test sizes reasonably small.
    default_ranks = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    rank_lo, rank_hi = op["rank"]
    if rankFilter is not None:
        # Keep only requested ranks that the operator allows.
        ranks = [rank for rank in rankFilter if rank_lo <= rank <= rank_hi]
    elif shapeFilter[0] is not None:
        ranks = range(rank_lo, rank_hi + 1)
    else:
        # Bound the default behaviour by the default range or by the
        # operator, whichever range of ranks is smaller.
        op_ranks = range(rank_lo, rank_hi + 1)
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        # List-valued dtype entries (e.g. mixed-type combos) match on their first element.
        dtypes = [
            dt
            for dt in op_dtypes
            if dt in dtypeFilter or (isinstance(dt, list) and dt[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }
    if testType != "negative" or validator is None:
        return None

    reqs = validator(check=False, op=op)["param_reqs"]
    req_rank = reqs["rank"]
    req_dtype = reqs["dtype"]
    req_shape = reqs["shape"]

    return {
        # Without a required shape, reduce the number of shapes to keep test numbers small.
        "shapeFilter": shapeFilter[:2] if req_shape is None else req_shape,
        "rankFilter": ranks if req_rank is None else req_rank,
        "dtypeFilter": dtypes if req_dtype is None else req_dtype,
    }
2868
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Enumerate the tests to generate for one operator.

    Args:
        opName: key into TOSA_OP_LIST.
        shapeFilter: optional list of shapes to restrict generation to;
            None (the default) means no shape restriction.
        rankFilter: optional list of ranks to restrict generation to.
        dtypeFilter: optional list of dtypes to restrict generation to.
        testType: "positive" for normal tests, "negative" for ERROR_IF tests.

    Returns:
        A list of (opName, testStr, dtype, error_name, shapeList, args)
        tuples. Positive tests rejected by the op's
        "invalid_test_validators" are removed.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    # The default used to be a mutable ``[None]`` list evaluated once and
    # shared by every call; build the "no shape restriction" list per call.
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of tuples of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of tuples of the (str, []) string
                    # representation and the build function argument list.
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive":
        # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
        if "invalid_test_validators" in op:
            invalid_test_validators = op["invalid_test_validators"]
            clean_testList = []
            for test in testList:
                remove_test = False
                # Every validator is consulted for every test (no short-circuit).
                for validator_fcn in invalid_test_validators:
                    if validator_fcn(
                        opName=test[0],
                        input_dtype=test[2],
                        shapeList=test[4],
                        args=test[5],
                    ):
                        remove_test = True
                if not remove_test:
                    clean_testList.append(test)
            testList = clean_testList

    return testList
2975
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
):
    """Generate, build and serialize one test.

    Args:
        opName: key into TOSA_OP_LIST.
        testStr: the test name.
        dtype_or_dtypeList: a single dtype (replicated per operand, or per
            shape for CONCAT/CONCAT_SHAPE) or an explicit per-operand list.
        error_name: the ERROR_IF condition being tested, or None.
        shapeList: one shape per operand.
        argsDict: build arguments; must be a dict containing "dg_type".

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    is_concat = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        # CONCAT ops take a variable number of inputs: one dtype per shape.
        repeat = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * repeat

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Check we are using the new interface with an argsDict dictionary
    assert isinstance(
        argsDict, dict
    ), f"{opName} is not using new tvg/build_fcn interface"
    assert "dg_type" in argsDict

    # Build the random tensor operands; extra meta data goes into desc.json.
    tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
    tensMeta = {}
    if tvgInfo.dataGenDict:
        tensMeta["data_gen"] = tvgInfo.dataGenDict

    result = build_fcn(
        self,
        op,
        tvgInfo.tensorList,
        argsDict,
        validator_fcns=error_if_validators,
        error_name=error_name,
        qinfo=qinfo,
    )

    if not result:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid: attach any compliance meta data and serialize it.
    if isinstance(result, TosaTestGen.BuildInfo):
        compliance = result.getComplianceInfo()
        if compliance:
            tensMeta["compliance"] = compliance
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003066
def createDynamicOpLists(self):
    """Expand the convolution *_TEMPLATE entries of TOSA_OP_LIST.

    Each template is copied once per kernel size, with "filter" set to the
    kernel and "template" set to False. Every entry whose "template" field
    is truthy is then removed. Returns early when "conv2d_TEMPLATE" is
    already gone, i.e. the lists were expanded before.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    if self.args.level8k:
        big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_k], [big_k, 2]]
        kernels_3d = [[1, big_k, 1], [2, 2, big_k]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def _instantiate(template_name, new_name, kernel):
        # Shallow copy is enough: only top-level keys are overwritten.
        entry = op_list[template_name].copy()
        entry["filter"] = kernel
        entry["template"] = False
        op_list[new_name] = entry

    for kernel in kernels_2d:
        suffix = "{}x{}".format(kernel[0], kernel[1])
        for family in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            _instantiate(family + "_TEMPLATE", family + "_" + suffix, kernel)

    for kernel in kernels_3d:
        _instantiate(
            "conv3d_TEMPLATE", "conv3d_{}x{}x{}".format(kernel[0], kernel[1], kernel[2]), kernel
        )

    # Collect first, delete afterwards: keys must not be removed from a
    # dictionary while iterating over it.
    templates = [name for name, entry in op_list.items() if entry.get("template")]
    for name in templates:
        del op_list[name]
3122
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.
    Look for missing required fields (datastructure linting).

    Each TOSA_OP_LIST entry must have an "operands" 2-tuple, a "build_fcn"
    4-tuple, a "types" list and an "op" field; otherwise an Exception naming
    the op is raised. A missing "rank" is set to DEFAULT_RANK_RANGE.
    """
    for name in self.TOSA_OP_LIST:
        entry = self.TOSA_OP_LIST[name]

        # (placeholder count, const count)
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(name)
            )

        # (build function, tensor gen, tensor values gen, argument gen)
        try:
            _build, _tgen, _tvgen, _agen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(name)
            )

        required = (
            ("types", "Op {} is missing a valid type list in TOSA_OP_LIST"),
            ("op", "Op {} is missing the Op field in TOSA_OP_LIST"),
        )
        for field, message in required:
            try:
                _ = entry[field]
            except KeyError:
                raise Exception(message.format(name))

        # Put in default rank range, if missing
        if "rank" not in entry:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003164
# Tensor operator list
# 'op': op name
# 'operands': tuple of (placeholder, const) operands
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#    if not specified, defaults to DEFAULT_RANK_RANGE (see initOpListDefaults)
# 'build_fcn': tuple of (build_operator() function, TensorGen function,
#    TensorValuesGen function, ArgGen function); a falsy ArgGen entry
#    means the test takes no extra arguments
# 'types': array of datatypes to be tested
# 'qgen': optional quantization info generator
# 'error_if_validators': optional validators used to create ERROR_IF
#    (negative) tests
# 'invalid_test_validators': optional validators that remove positive tests
#    which would fail without matching an ERROR_IF condition
# 'data_gen': optional, maps a type class (e.g. "fp") to a tuple of
#    gtu.DataGenType values
# 'template': True for *_TEMPLATE entries, which createDynamicOpLists
#    expands per kernel size ('filter') and then removes
# Floating-point types (the FP8 types are not included)
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
# Floating-point, integer (excluding INT4) and BOOL types
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
TYPE_FI16 = [DType.FP32, DType.INT16]

# INT8/INT16 plus the floating-point types (no INT32)
TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]

# List of [Input Type 1, Input Type 2, Accumulator Type]
# Used as the 'types' of the convolution templates; when a dtype filter is
# applied, these list entries are matched on their first element.
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
    [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
    [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
]

# Rank range given to ops whose TOSA_OP_LIST entry has no 'rank'
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003224
3225 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003226 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003227 "argmax": {
3228 "op": Op.ARGMAX,
3229 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003230 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003231 "build_fcn": (
3232 build_argmax,
3233 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003234 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003235 TosaArgGen.agAxis,
3236 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003237 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003238 "error_if_validators": (
3239 TosaErrorValidator.evAxisSmallerZero,
3240 TosaErrorValidator.evAxisLargerRank,
3241 TosaErrorValidator.evArgmaxOutputRankMismatch,
3242 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3243 TosaErrorValidator.evWrongRank,
3244 TosaErrorValidator.evWrongInputType,
3245 TosaErrorValidator.evWrongOutputType,
3246 TosaErrorValidator.evWrongInputList,
3247 TosaErrorValidator.evWrongOutputList,
3248 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003249 "data_gen": {
3250 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3251 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003252 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003253 "avg_pool2d": {
3254 "op": Op.AVG_POOL2D,
3255 "operands": (1, 0),
3256 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003257 "build_fcn": (
3258 build_pool2d,
3259 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003260 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003261 TosaArgGen.agPooling,
3262 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003263 "qgen": TosaQuantGen.qgUnary,
Won Jeon2c34b462024-02-06 18:37:00 +00003264 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003265 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003266 "error_if_validators": (
3267 TosaErrorValidator.evKernelSmallerOne,
3268 TosaErrorValidator.evStrideSmallerOne,
3269 TosaErrorValidator.evPadSmallerZero,
3270 TosaErrorValidator.evWrongRank,
3271 TosaErrorValidator.evWrongInputType,
3272 TosaErrorValidator.evWrongOutputType,
3273 TosaErrorValidator.evWrongInputList,
3274 TosaErrorValidator.evWrongOutputList,
3275 TosaErrorValidator.evInputZeroPointNotZero,
3276 TosaErrorValidator.evOutputZeroPointNotZero,
3277 TosaErrorValidator.evPadLargerEqualKernel,
3278 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003279 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00003280 TosaErrorValidator.evWrongAccumulatorType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003281 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003282 "data_gen": {
3283 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3284 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003285 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003286 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003287 "conv2d_TEMPLATE": {
3288 "op": Op.CONV2D,
3289 "operands": (1, 2),
3290 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003291 "build_fcn": (
3292 build_conv2d,
3293 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003294 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003295 TosaArgGen.agConv,
3296 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003297 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003298 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003299 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3300 "error_if_validators": (
3301 TosaErrorValidator.evWrongInputType,
3302 TosaErrorValidator.evWrongOutputType,
3303 TosaErrorValidator.evWrongInputList,
3304 TosaErrorValidator.evWrongOutputList,
3305 TosaErrorValidator.evInputZeroPointNotZero,
3306 TosaErrorValidator.evWeightZeroPointNotZero,
3307 TosaErrorValidator.evPadSmallerZero,
3308 TosaErrorValidator.evStrideSmallerOne,
3309 TosaErrorValidator.evDilationSmallerOne,
3310 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003311 TosaErrorValidator.evConvOutputShapeMismatch,
3312 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003313 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003314 "data_gen": {
3315 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3316 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003317 "template": True,
3318 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003319 # Templated operator. Filled in by createDynamicOpLists
3320 "conv3d_TEMPLATE": {
3321 "op": Op.CONV3D,
3322 "operands": (1, 2),
3323 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003324 "build_fcn": (
3325 build_conv3d,
3326 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003327 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003328 TosaArgGen.agConv,
3329 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003330 "qgen": TosaQuantGen.qgConv,
3331 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003332 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3333 "error_if_validators": (
3334 TosaErrorValidator.evWrongInputType,
3335 TosaErrorValidator.evWrongOutputType,
3336 TosaErrorValidator.evWrongInputList,
3337 TosaErrorValidator.evWrongOutputList,
3338 TosaErrorValidator.evInputZeroPointNotZero,
3339 TosaErrorValidator.evWeightZeroPointNotZero,
3340 TosaErrorValidator.evPadSmallerZero,
3341 TosaErrorValidator.evStrideSmallerOne,
3342 TosaErrorValidator.evDilationSmallerOne,
3343 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003344 TosaErrorValidator.evConvOutputShapeMismatch,
3345 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003346 ),
evacha0147ab1762024-01-29 13:23:23 +00003347 "data_gen": {
3348 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3349 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003350 "template": True,
3351 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003352 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003353 "depthwise_conv2d_TEMPLATE": {
3354 "op": Op.DEPTHWISE_CONV2D,
3355 "operands": (1, 2),
3356 "filter": [1, 1],
3357 "rank": (4, 4),
3358 "build_fcn": (
3359 build_depthwise_conv2d,
3360 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003361 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003362 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003363 ),
3364 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003365 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003366 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3367 "error_if_validators": (
3368 TosaErrorValidator.evWrongInputType,
3369 TosaErrorValidator.evWrongOutputType,
3370 TosaErrorValidator.evWrongInputList,
3371 TosaErrorValidator.evWrongOutputList,
3372 TosaErrorValidator.evInputZeroPointNotZero,
3373 TosaErrorValidator.evWeightZeroPointNotZero,
3374 TosaErrorValidator.evPadSmallerZero,
3375 TosaErrorValidator.evStrideSmallerOne,
3376 TosaErrorValidator.evDilationSmallerOne,
3377 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003378 TosaErrorValidator.evConvOutputShapeMismatch,
3379 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003380 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003381 "data_gen": {
3382 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3383 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003384 "template": True,
3385 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003386 "fully_connected": {
3387 "op": Op.FULLY_CONNECTED,
3388 "operands": (1, 2),
3389 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003390 "build_fcn": (
3391 build_fully_connected,
3392 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003393 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003394 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003395 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003396 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003397 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003398 "error_if_validators": (
3399 TosaErrorValidator.evInputZeroPointNotZero,
3400 TosaErrorValidator.evWeightZeroPointNotZero,
3401 TosaErrorValidator.evWrongRank,
3402 TosaErrorValidator.evWrongInputType,
3403 TosaErrorValidator.evWrongOutputType,
3404 TosaErrorValidator.evWrongInputList,
3405 TosaErrorValidator.evWrongOutputList,
3406 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003407 "data_gen": {
3408 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3409 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003410 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003411 "matmul": {
3412 "op": Op.MATMUL,
3413 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003414 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003415 "build_fcn": (
3416 build_matmul,
3417 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003418 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003419 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003420 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003421 "qgen": TosaQuantGen.qgMatmul,
Won Jeon2c34b462024-02-06 18:37:00 +00003422 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003423 "error_if_validators": (
3424 TosaErrorValidator.evInputZeroPointNotZero,
3425 TosaErrorValidator.evWrongRank,
3426 TosaErrorValidator.evWrongInputType,
3427 TosaErrorValidator.evWrongOutputType,
3428 TosaErrorValidator.evWrongInputList,
3429 TosaErrorValidator.evWrongOutputList,
3430 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003431 "data_gen": {
3432 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003433 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003434 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003435 "max_pool2d": {
3436 "op": Op.MAX_POOL2D,
3437 "operands": (1, 0),
3438 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003439 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003440 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003441 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003442 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003443 TosaArgGen.agPooling,
3444 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003445 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003446 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003447 "error_if_validators": (
3448 TosaErrorValidator.evKernelSmallerOne,
3449 TosaErrorValidator.evStrideSmallerOne,
3450 TosaErrorValidator.evPadSmallerZero,
3451 TosaErrorValidator.evWrongRank,
3452 TosaErrorValidator.evWrongInputType,
3453 TosaErrorValidator.evWrongOutputType,
3454 TosaErrorValidator.evWrongInputList,
3455 TosaErrorValidator.evWrongOutputList,
3456 TosaErrorValidator.evPadLargerEqualKernel,
3457 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003458 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003459 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003460 "data_gen": {
3461 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3462 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003463 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003464 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003465 "transpose_conv2d_TEMPLATE": {
3466 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003467 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003468 "rank": (4, 4),
3469 "build_fcn": (
3470 build_transpose_conv2d,
3471 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003472 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003473 TosaArgGen.agTransposeConv2D,
3474 ),
3475 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003476 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003477 "invalid_test_validators": (
3478 TosaInvalidValidator.ivHeightWidthInvalid,
3479 TosaInvalidValidator.ivNonPositiveOutputShape,
3480 ),
3481 "error_if_validators": (
3482 TosaErrorValidator.evWrongInputType,
3483 TosaErrorValidator.evWrongOutputType,
3484 TosaErrorValidator.evWrongInputList,
3485 TosaErrorValidator.evWrongOutputList,
3486 TosaErrorValidator.evInputZeroPointNotZero,
3487 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003488 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003489 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003490 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003491 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003492 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00003493 "data_gen": {
3494 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3495 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003496 "template": True,
3497 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003498 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003499 "clamp": {
3500 "op": Op.CLAMP,
3501 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003502 "build_fcn": (
3503 build_clamp,
3504 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003505 TosaTensorValuesGen.tvgLazyGenDefault,
3506 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003507 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003508 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003509 "error_if_validators": (
3510 TosaErrorValidator.evMaxSmallerMin,
3511 TosaErrorValidator.evWrongInputType,
3512 TosaErrorValidator.evWrongOutputType,
3513 TosaErrorValidator.evWrongInputList,
3514 TosaErrorValidator.evWrongOutputList,
3515 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003516 "data_gen": {
3517 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3518 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003519 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003520 "sigmoid": {
3521 "op": Op.SIGMOID,
3522 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003523 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003524 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003525 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003526 TosaTensorValuesGen.tvgLazyGenDefault,
3527 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003528 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003529 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003530 "error_if_validators": (
3531 TosaErrorValidator.evWrongInputType,
3532 TosaErrorValidator.evWrongOutputType,
3533 TosaErrorValidator.evWrongInputList,
3534 TosaErrorValidator.evWrongOutputList,
3535 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003536 "data_gen": {
3537 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3538 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003539 },
3540 "tanh": {
3541 "op": Op.TANH,
3542 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003543 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003544 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003545 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003546 TosaTensorValuesGen.tvgLazyGenDefault,
3547 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003548 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003549 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003550 "error_if_validators": (
3551 TosaErrorValidator.evWrongInputType,
3552 TosaErrorValidator.evWrongOutputType,
3553 TosaErrorValidator.evWrongInputList,
3554 TosaErrorValidator.evWrongOutputList,
3555 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003556 "data_gen": {
3557 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3558 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003559 "compliance": {
3560 "abs_error_lower_bound": 0.5,
3561 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003562 },
Won Jeon78155c62023-06-10 00:20:04 +00003563 "erf": {
3564 "op": Op.ERF,
3565 "operands": (1, 0),
3566 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003567 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003568 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003569 TosaTensorValuesGen.tvgLazyGenDefault,
3570 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003571 ),
3572 "types": TYPE_FP,
3573 "error_if_validators": (
3574 TosaErrorValidator.evWrongInputType,
3575 TosaErrorValidator.evWrongOutputType,
3576 TosaErrorValidator.evWrongInputList,
3577 TosaErrorValidator.evWrongOutputList,
3578 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003579 "data_gen": {
3580 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3581 },
3582 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003583 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003584 # Elementwise Binary Operators
3585 "add": {
3586 "op": Op.ADD,
3587 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003588 "build_fcn": (
3589 build_binary_broadcast,
3590 TosaTensorGen.tgBroadcastFuzz,
3591 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003592 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003593 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003594 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003595 "error_if_validators": (
3596 TosaErrorValidator.evRankMismatch,
3597 TosaErrorValidator.evWrongInputType,
3598 TosaErrorValidator.evWrongOutputType,
3599 TosaErrorValidator.evWrongInputList,
3600 TosaErrorValidator.evWrongOutputList,
3601 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003602 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003603 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003604 "data_gen": {
3605 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3606 },
3607 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003608 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003609 "arithmetic_right_shift": {
3610 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3611 "operands": (2, 0),
3612 "build_fcn": (
3613 build_arithmetic_right_shift,
3614 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003615 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003616 TosaArgGen.agArithmeticRightShift,
3617 ),
3618 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003619 "error_if_validators": (
3620 TosaErrorValidator.evRankMismatch,
3621 TosaErrorValidator.evWrongInputType,
3622 TosaErrorValidator.evWrongOutputType,
3623 TosaErrorValidator.evWrongInputList,
3624 TosaErrorValidator.evWrongOutputList,
3625 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003626 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003627 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003628 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003629 "bitwise_and": {
3630 "op": Op.BITWISE_AND,
3631 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003632 "build_fcn": (
3633 build_binary_broadcast,
3634 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003635 TosaTensorValuesGen.tvgLazyGenDefault,
3636 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003637 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003638 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003639 "error_if_validators": (
3640 TosaErrorValidator.evRankMismatch,
3641 TosaErrorValidator.evWrongInputType,
3642 TosaErrorValidator.evWrongOutputType,
3643 TosaErrorValidator.evWrongInputList,
3644 TosaErrorValidator.evWrongOutputList,
3645 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003646 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003647 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003648 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003649 "bitwise_or": {
3650 "op": Op.BITWISE_OR,
3651 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003652 "build_fcn": (
3653 build_binary_broadcast,
3654 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003655 TosaTensorValuesGen.tvgLazyGenDefault,
3656 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003657 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003658 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003659 "error_if_validators": (
3660 TosaErrorValidator.evRankMismatch,
3661 TosaErrorValidator.evWrongInputType,
3662 TosaErrorValidator.evWrongOutputType,
3663 TosaErrorValidator.evWrongInputList,
3664 TosaErrorValidator.evWrongOutputList,
3665 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003666 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003667 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003668 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003669 "bitwise_xor": {
3670 "op": Op.BITWISE_XOR,
3671 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003672 "build_fcn": (
3673 build_binary_broadcast,
3674 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003675 TosaTensorValuesGen.tvgLazyGenDefault,
3676 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003677 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003678 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003679 "error_if_validators": (
3680 TosaErrorValidator.evRankMismatch,
3681 TosaErrorValidator.evWrongInputType,
3682 TosaErrorValidator.evWrongOutputType,
3683 TosaErrorValidator.evWrongInputList,
3684 TosaErrorValidator.evWrongOutputList,
3685 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003686 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003687 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003688 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003689 "intdiv": {
3690 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003691 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003692 "build_fcn": (
3693 build_binary_broadcast,
3694 TosaTensorGen.tgBroadcastFuzz,
3695 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003696 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003697 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003698 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003699 "error_if_validators": (
3700 TosaErrorValidator.evRankMismatch,
3701 TosaErrorValidator.evWrongInputType,
3702 TosaErrorValidator.evWrongOutputType,
3703 TosaErrorValidator.evWrongInputList,
3704 TosaErrorValidator.evWrongOutputList,
3705 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003706 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003707 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003708 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003709 "logical_and": {
3710 "op": Op.LOGICAL_AND,
3711 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003712 "build_fcn": (
3713 build_binary_broadcast,
3714 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003715 TosaTensorValuesGen.tvgLazyGenDefault,
3716 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003717 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003718 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003719 "error_if_validators": (
3720 TosaErrorValidator.evRankMismatch,
3721 TosaErrorValidator.evWrongInputType,
3722 TosaErrorValidator.evWrongOutputType,
3723 TosaErrorValidator.evWrongInputList,
3724 TosaErrorValidator.evWrongOutputList,
3725 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003726 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003727 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003728 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003729 "logical_left_shift": {
3730 "op": Op.LOGICAL_LEFT_SHIFT,
3731 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003732 "build_fcn": (
3733 build_binary_broadcast,
3734 TosaTensorGen.tgBroadcastFuzz,
3735 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003736 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003737 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003738 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003739 "error_if_validators": (
3740 TosaErrorValidator.evRankMismatch,
3741 TosaErrorValidator.evWrongInputType,
3742 TosaErrorValidator.evWrongOutputType,
3743 TosaErrorValidator.evWrongInputList,
3744 TosaErrorValidator.evWrongOutputList,
3745 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003746 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003747 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003748 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003749 "logical_right_shift": {
3750 "op": Op.LOGICAL_RIGHT_SHIFT,
3751 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003752 "build_fcn": (
3753 build_binary_broadcast,
3754 TosaTensorGen.tgBroadcastFuzz,
3755 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003756 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003757 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003758 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003759 "error_if_validators": (
3760 TosaErrorValidator.evRankMismatch,
3761 TosaErrorValidator.evWrongInputType,
3762 TosaErrorValidator.evWrongOutputType,
3763 TosaErrorValidator.evWrongInputList,
3764 TosaErrorValidator.evWrongOutputList,
3765 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003766 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003767 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003768 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003769 "logical_or": {
3770 "op": Op.LOGICAL_OR,
3771 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003772 "build_fcn": (
3773 build_binary_broadcast,
3774 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003775 TosaTensorValuesGen.tvgLazyGenDefault,
3776 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003777 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003778 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003779 "error_if_validators": (
3780 TosaErrorValidator.evRankMismatch,
3781 TosaErrorValidator.evWrongInputType,
3782 TosaErrorValidator.evWrongOutputType,
3783 TosaErrorValidator.evWrongInputList,
3784 TosaErrorValidator.evWrongOutputList,
3785 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003786 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003787 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003788 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003789 "logical_xor": {
3790 "op": Op.LOGICAL_XOR,
3791 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003792 "build_fcn": (
3793 build_binary_broadcast,
3794 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003795 TosaTensorValuesGen.tvgLazyGenDefault,
3796 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003797 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003798 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003799 "error_if_validators": (
3800 TosaErrorValidator.evRankMismatch,
3801 TosaErrorValidator.evWrongInputType,
3802 TosaErrorValidator.evWrongOutputType,
3803 TosaErrorValidator.evWrongInputList,
3804 TosaErrorValidator.evWrongOutputList,
3805 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003806 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003807 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003808 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003809 "maximum": {
3810 "op": Op.MAXIMUM,
3811 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003812 "build_fcn": (
3813 build_binary_broadcast,
3814 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003815 TosaTensorValuesGen.tvgLazyGenDefault,
3816 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003817 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003818 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003819 "error_if_validators": (
3820 TosaErrorValidator.evRankMismatch,
3821 TosaErrorValidator.evWrongInputType,
3822 TosaErrorValidator.evWrongOutputType,
3823 TosaErrorValidator.evWrongInputList,
3824 TosaErrorValidator.evWrongOutputList,
3825 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003826 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003827 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003828 "data_gen": {
3829 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3830 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003831 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003832 "minimum": {
3833 "op": Op.MINIMUM,
3834 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003835 "build_fcn": (
3836 build_binary_broadcast,
3837 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003838 TosaTensorValuesGen.tvgLazyGenDefault,
3839 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003840 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003841 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003842 "error_if_validators": (
3843 TosaErrorValidator.evRankMismatch,
3844 TosaErrorValidator.evWrongInputType,
3845 TosaErrorValidator.evWrongOutputType,
3846 TosaErrorValidator.evWrongInputList,
3847 TosaErrorValidator.evWrongOutputList,
3848 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003849 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003850 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003851 "data_gen": {
3852 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3853 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003854 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003855 "mul": {
3856 "op": Op.MUL,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003857 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003858 "build_fcn": (
3859 build_mul,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003860 TosaTensorGen.tgMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003861 TosaTensorValuesGen.tvgMul,
3862 TosaArgGen.agMul,
3863 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003864 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003865 "error_if_validators": (
3866 TosaErrorValidator.evWrongInputType,
3867 TosaErrorValidator.evWrongOutputType,
3868 TosaErrorValidator.evWrongInputList,
3869 TosaErrorValidator.evWrongOutputList,
3870 TosaErrorValidator.evRankMismatch,
3871 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003872 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003873 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003874 "data_gen": {
3875 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3876 },
3877 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003878 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003879 "pow": {
3880 "op": Op.POW,
3881 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003882 "build_fcn": (
3883 build_binary_broadcast,
3884 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003885 TosaTensorValuesGen.tvgPow,
3886 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003887 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003888 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003889 "error_if_validators": (
3890 TosaErrorValidator.evRankMismatch,
3891 TosaErrorValidator.evWrongInputType,
3892 TosaErrorValidator.evWrongOutputType,
3893 TosaErrorValidator.evWrongInputList,
3894 TosaErrorValidator.evWrongOutputList,
3895 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003896 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003897 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003898 "data_gen": {
3899 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3900 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003901 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003902 "sub": {
3903 "op": Op.SUB,
3904 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003905 "build_fcn": (
3906 build_binary_broadcast,
3907 TosaTensorGen.tgBroadcastFuzz,
3908 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003909 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003910 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003911 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003912 "error_if_validators": (
3913 TosaErrorValidator.evRankMismatch,
3914 TosaErrorValidator.evWrongInputType,
3915 TosaErrorValidator.evWrongOutputType,
3916 TosaErrorValidator.evWrongInputList,
3917 TosaErrorValidator.evWrongOutputList,
3918 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003919 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003920 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003921 "data_gen": {
3922 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3923 },
3924 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003925 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003926 "table": {
3927 "op": Op.TABLE,
3928 # Use the automatic generation functions to create the input array
3929 # but create the table tensor in the build function, as it may be
3930 # a different type from the input
3931 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003932 "build_fcn": (
3933 build_table,
3934 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003935 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003936 TosaArgGen.agTable,
3937 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003938 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003939 "error_if_validators": (
3940 TosaErrorValidator.evWrongInputType,
3941 TosaErrorValidator.evWrongOutputType,
3942 TosaErrorValidator.evWrongInputList,
3943 TosaErrorValidator.evWrongOutputList,
3944 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003945 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003946 # Elementwise Unary operators
3947 "abs": {
3948 "op": Op.ABS,
3949 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003950 "build_fcn": (
3951 build_unary,
3952 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003953 TosaTensorValuesGen.tvgLazyGenDefault,
3954 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003955 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003956 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003957 "error_if_validators": (
3958 TosaErrorValidator.evWrongInputType,
3959 TosaErrorValidator.evWrongOutputType,
3960 TosaErrorValidator.evWrongInputList,
3961 TosaErrorValidator.evWrongOutputList,
3962 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003963 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00003964 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003965 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003966 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003967 "bitwise_not": {
3968 "op": Op.BITWISE_NOT,
3969 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003970 "build_fcn": (
3971 build_unary,
3972 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003973 TosaTensorValuesGen.tvgLazyGenDefault,
3974 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003975 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003976 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003977 "error_if_validators": (
3978 TosaErrorValidator.evWrongInputType,
3979 TosaErrorValidator.evWrongOutputType,
3980 TosaErrorValidator.evWrongInputList,
3981 TosaErrorValidator.evWrongOutputList,
3982 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003983 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003984 "ceil": {
3985 "op": Op.CEIL,
3986 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003987 "build_fcn": (
3988 build_unary,
3989 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003990 TosaTensorValuesGen.tvgLazyGenDefault,
3991 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003992 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003993 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003994 "error_if_validators": (
3995 TosaErrorValidator.evWrongInputType,
3996 TosaErrorValidator.evWrongOutputType,
3997 TosaErrorValidator.evWrongInputList,
3998 TosaErrorValidator.evWrongOutputList,
3999 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004000 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004001 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004002 },
4003 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004004 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004005 "clz": {
4006 "op": Op.CLZ,
4007 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004008 "build_fcn": (
4009 build_unary,
4010 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004011 TosaTensorValuesGen.tvgLazyGenDefault,
4012 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004013 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004014 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004015 "error_if_validators": (
4016 TosaErrorValidator.evWrongInputType,
4017 TosaErrorValidator.evWrongOutputType,
4018 TosaErrorValidator.evWrongInputList,
4019 TosaErrorValidator.evWrongOutputList,
4020 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004021 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004022 "cos": {
4023 "op": Op.COS,
4024 "operands": (1, 0),
4025 "build_fcn": (
4026 build_unary,
4027 TosaTensorGen.tgBasic,
4028 TosaTensorValuesGen.tvgLazyGenDefault,
4029 TosaArgGen.agNone,
4030 ),
4031 "types": TYPE_FP,
4032 "error_if_validators": (
4033 TosaErrorValidator.evWrongInputType,
4034 TosaErrorValidator.evWrongOutputType,
4035 TosaErrorValidator.evWrongInputList,
4036 TosaErrorValidator.evWrongOutputList,
4037 ),
4038 "data_gen": {
4039 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4040 },
4041 "compliance": {"abs_error_normal_divisor": 2},
4042 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004043 "exp": {
4044 "op": Op.EXP,
4045 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004046 "build_fcn": (
4047 build_unary,
4048 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004049 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004050 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004051 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004052 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004053 "error_if_validators": (
4054 TosaErrorValidator.evWrongInputType,
4055 TosaErrorValidator.evWrongOutputType,
4056 TosaErrorValidator.evWrongInputList,
4057 TosaErrorValidator.evWrongOutputList,
4058 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004059 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004060 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004061 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004062 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004063 "floor": {
4064 "op": Op.FLOOR,
4065 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004066 "build_fcn": (
4067 build_unary,
4068 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004069 TosaTensorValuesGen.tvgLazyGenDefault,
4070 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004071 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004072 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004073 "error_if_validators": (
4074 TosaErrorValidator.evWrongInputType,
4075 TosaErrorValidator.evWrongOutputType,
4076 TosaErrorValidator.evWrongInputList,
4077 TosaErrorValidator.evWrongOutputList,
4078 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004079 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004080 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004081 },
4082 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004083 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004084 "log": {
4085 "op": Op.LOG,
4086 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004087 "build_fcn": (
4088 build_unary,
4089 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004090 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004091 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004092 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004093 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004094 "error_if_validators": (
4095 TosaErrorValidator.evWrongInputType,
4096 TosaErrorValidator.evWrongOutputType,
4097 TosaErrorValidator.evWrongInputList,
4098 TosaErrorValidator.evWrongOutputList,
4099 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004100 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004101 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004102 },
4103 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004104 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004105 "logical_not": {
4106 "op": Op.LOGICAL_NOT,
4107 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004108 "build_fcn": (
4109 build_unary,
4110 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004111 TosaTensorValuesGen.tvgLazyGenDefault,
4112 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004113 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004114 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004115 "error_if_validators": (
4116 TosaErrorValidator.evWrongInputType,
4117 TosaErrorValidator.evWrongOutputType,
4118 TosaErrorValidator.evWrongInputList,
4119 TosaErrorValidator.evWrongOutputList,
4120 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004121 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004122 "negate": {
4123 "op": Op.NEGATE,
4124 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004125 "build_fcn": (
4126 build_unary,
4127 TosaTensorGen.tgBasic,
4128 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004129 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004130 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004131 "qgen": TosaQuantGen.qgUnary,
4132 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004133 "error_if_validators": (
4134 TosaErrorValidator.evInputZeroPointNotZero,
4135 TosaErrorValidator.evOutputZeroPointNotZero,
4136 TosaErrorValidator.evWrongInputType,
4137 TosaErrorValidator.evWrongOutputType,
4138 TosaErrorValidator.evWrongInputList,
4139 TosaErrorValidator.evWrongOutputList,
4140 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004141 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004142 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004143 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004144 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004145 "reciprocal": {
4146 "op": Op.RECIPROCAL,
4147 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004148 "build_fcn": (
4149 build_unary,
4150 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004151 TosaTensorValuesGen.tvgLazyGenDefault,
4152 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004153 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004154 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004155 "error_if_validators": (
4156 TosaErrorValidator.evWrongInputType,
4157 TosaErrorValidator.evWrongOutputType,
4158 TosaErrorValidator.evWrongInputList,
4159 TosaErrorValidator.evWrongOutputList,
4160 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004161 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004162 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004163 },
4164 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004165 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004166 "rsqrt": {
4167 "op": Op.RSQRT,
4168 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004169 "build_fcn": (
4170 build_unary,
4171 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004172 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004173 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004174 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004175 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004176 "error_if_validators": (
4177 TosaErrorValidator.evWrongInputType,
4178 TosaErrorValidator.evWrongOutputType,
4179 TosaErrorValidator.evWrongInputList,
4180 TosaErrorValidator.evWrongOutputList,
4181 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004182 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004183 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004184 },
4185 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004186 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004187 "sin": {
4188 "op": Op.SIN,
4189 "operands": (1, 0),
4190 "build_fcn": (
4191 build_unary,
4192 TosaTensorGen.tgBasic,
4193 TosaTensorValuesGen.tvgLazyGenDefault,
4194 TosaArgGen.agNone,
4195 ),
4196 "types": TYPE_FP,
4197 "error_if_validators": (
4198 TosaErrorValidator.evWrongInputType,
4199 TosaErrorValidator.evWrongOutputType,
4200 TosaErrorValidator.evWrongInputList,
4201 TosaErrorValidator.evWrongOutputList,
4202 ),
4203 "data_gen": {
4204 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4205 },
4206 "compliance": {"abs_error_normal_divisor": 2},
4207 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004208 # Elementwise Ternary operators
4209 "select": {
4210 "op": Op.SELECT,
4211 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004212 "build_fcn": (
4213 build_select,
4214 TosaTensorGen.tgBroadcastFuzz,
4215 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004216 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004217 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004218 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004219 "error_if_validators": (
4220 TosaErrorValidator.evRankMismatch,
4221 TosaErrorValidator.evWrongInputType,
4222 TosaErrorValidator.evWrongOutputType,
4223 TosaErrorValidator.evWrongInputList,
4224 TosaErrorValidator.evWrongOutputList,
4225 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004226 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004227 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004228 "data_gen": {
4229 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4230 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004231 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004232 # Comparison operators
4233 "equal": {
4234 "op": Op.EQUAL,
4235 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004236 "build_fcn": (
4237 build_comparison,
4238 TosaTensorGen.tgBroadcastFuzz,
4239 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004240 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004241 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004242 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004243 "error_if_validators": (
4244 TosaErrorValidator.evRankMismatch,
4245 TosaErrorValidator.evWrongInputType,
4246 TosaErrorValidator.evWrongOutputType,
4247 TosaErrorValidator.evWrongInputList,
4248 TosaErrorValidator.evWrongOutputList,
4249 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004250 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004251 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004252 "data_gen": {
4253 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4254 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004255 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004256 "greater_equal": {
4257 "op": Op.GREATER_EQUAL,
4258 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004259 "build_fcn": (
4260 build_comparison,
4261 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004262 TosaTensorValuesGen.tvgLazyGenDefault,
4263 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004264 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004265 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004266 "error_if_validators": (
4267 TosaErrorValidator.evRankMismatch,
4268 TosaErrorValidator.evWrongInputType,
4269 TosaErrorValidator.evWrongOutputType,
4270 TosaErrorValidator.evWrongInputList,
4271 TosaErrorValidator.evWrongOutputList,
4272 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004273 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004274 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004275 "data_gen": {
4276 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4277 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004278 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004279 "greater": {
4280 "op": Op.GREATER,
4281 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004282 "build_fcn": (
4283 build_comparison,
4284 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004285 TosaTensorValuesGen.tvgLazyGenDefault,
4286 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004287 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004288 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004289 "error_if_validators": (
4290 TosaErrorValidator.evRankMismatch,
4291 TosaErrorValidator.evWrongInputType,
4292 TosaErrorValidator.evWrongOutputType,
4293 TosaErrorValidator.evWrongInputList,
4294 TosaErrorValidator.evWrongOutputList,
4295 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004296 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004297 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004298 "data_gen": {
4299 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4300 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004301 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004302 # Reduction operators
4303 "reduce_all": {
4304 "op": Op.REDUCE_ALL,
4305 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004306 "build_fcn": (
4307 build_reduce,
4308 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004309 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004310 TosaArgGen.agAxis,
4311 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004312 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004313 "error_if_validators": (
4314 TosaErrorValidator.evAxisLargerRank,
4315 TosaErrorValidator.evAxisSmallerZero,
4316 TosaErrorValidator.evShapeOfAxisNotOne,
4317 TosaErrorValidator.evWrongInputType,
4318 TosaErrorValidator.evWrongOutputType,
4319 TosaErrorValidator.evWrongRank,
4320 TosaErrorValidator.evWrongInputList,
4321 TosaErrorValidator.evWrongOutputList,
4322 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004323 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004324 "reduce_any": {
4325 "op": Op.REDUCE_ANY,
4326 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004327 "build_fcn": (
4328 build_reduce,
4329 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004330 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004331 TosaArgGen.agAxis,
4332 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004333 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004334 "error_if_validators": (
4335 TosaErrorValidator.evAxisLargerRank,
4336 TosaErrorValidator.evAxisSmallerZero,
4337 TosaErrorValidator.evShapeOfAxisNotOne,
4338 TosaErrorValidator.evWrongInputType,
4339 TosaErrorValidator.evWrongOutputType,
4340 TosaErrorValidator.evWrongRank,
4341 TosaErrorValidator.evWrongInputList,
4342 TosaErrorValidator.evWrongOutputList,
4343 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004344 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004345 "reduce_max": {
4346 "op": Op.REDUCE_MAX,
4347 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004348 "build_fcn": (
4349 build_reduce,
4350 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004351 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004352 TosaArgGen.agAxis,
4353 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004354 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004355 "error_if_validators": (
4356 TosaErrorValidator.evAxisLargerRank,
4357 TosaErrorValidator.evAxisSmallerZero,
4358 TosaErrorValidator.evShapeOfAxisNotOne,
4359 TosaErrorValidator.evWrongInputType,
4360 TosaErrorValidator.evWrongOutputType,
4361 TosaErrorValidator.evWrongRank,
4362 TosaErrorValidator.evWrongInputList,
4363 TosaErrorValidator.evWrongOutputList,
4364 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004365 "data_gen": {
4366 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4367 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004368 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004369 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004370 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004371 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004372 "build_fcn": (
4373 build_reduce,
4374 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004375 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004376 TosaArgGen.agAxis,
4377 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004378 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004379 "error_if_validators": (
4380 TosaErrorValidator.evAxisLargerRank,
4381 TosaErrorValidator.evAxisSmallerZero,
4382 TosaErrorValidator.evShapeOfAxisNotOne,
4383 TosaErrorValidator.evWrongInputType,
4384 TosaErrorValidator.evWrongOutputType,
4385 TosaErrorValidator.evWrongRank,
4386 TosaErrorValidator.evWrongInputList,
4387 TosaErrorValidator.evWrongOutputList,
4388 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004389 "data_gen": {
4390 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4391 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004392 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004393 "reduce_product": {
4394 "op": Op.REDUCE_PRODUCT,
4395 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004396 "build_fcn": (
4397 build_reduce,
4398 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004399 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004400 TosaArgGen.agAxis,
4401 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004402 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004403 "error_if_validators": (
4404 TosaErrorValidator.evAxisLargerRank,
4405 TosaErrorValidator.evAxisSmallerZero,
4406 TosaErrorValidator.evShapeOfAxisNotOne,
4407 TosaErrorValidator.evWrongInputType,
4408 TosaErrorValidator.evWrongOutputType,
4409 TosaErrorValidator.evWrongRank,
4410 TosaErrorValidator.evWrongInputList,
4411 TosaErrorValidator.evWrongOutputList,
4412 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004413 "data_gen": {
4414 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4415 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004416 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004417 "reduce_sum": {
4418 "op": Op.REDUCE_SUM,
4419 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004420 "build_fcn": (
4421 build_reduce,
4422 TosaTensorGen.tgBasic,
4423 TosaTensorValuesGen.tvgReduceSum,
4424 TosaArgGen.agAxis,
4425 ),
James Ward24dbc422022-10-19 12:20:31 +01004426 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004427 "error_if_validators": (
4428 TosaErrorValidator.evAxisLargerRank,
4429 TosaErrorValidator.evAxisSmallerZero,
4430 TosaErrorValidator.evShapeOfAxisNotOne,
4431 TosaErrorValidator.evWrongInputType,
4432 TosaErrorValidator.evWrongOutputType,
4433 TosaErrorValidator.evWrongRank,
4434 TosaErrorValidator.evWrongInputList,
4435 TosaErrorValidator.evWrongOutputList,
4436 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004437 "data_gen": {
4438 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4439 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004440 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004441 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004442 "concat": {
4443 "op": Op.CONCAT,
4444 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004445 "build_fcn": (
4446 build_concat,
4447 TosaTensorGen.tgConcat,
4448 TosaTensorValuesGen.tvgConcat,
4449 TosaArgGen.agAxis,
4450 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004451 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004452 "error_if_validators": (
4453 TosaErrorValidator.evAxisLargerRank,
4454 TosaErrorValidator.evAxisSmallerZero,
4455 TosaErrorValidator.evConcatInputRankMismatch,
4456 TosaErrorValidator.evConcatShapeSumMismatch,
4457 TosaErrorValidator.evConcatInputDimMismatch,
4458 TosaErrorValidator.evWrongInputType,
4459 TosaErrorValidator.evWrongOutputType,
4460 TosaErrorValidator.evWrongOutputList,
4461 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004462 "data_gen": {
4463 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4464 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004465 },
4466 "pad": {
4467 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004468 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004469 "build_fcn": (
4470 build_pad,
4471 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004472 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004473 TosaArgGen.agPad,
4474 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004475 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004476 "error_if_validators": (
4477 TosaErrorValidator.evWrongInputType,
4478 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004479 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004480 TosaErrorValidator.evWrongOutputType,
4481 TosaErrorValidator.evWrongInputList,
4482 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004483 TosaErrorValidator.evRankMismatch,
4484 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004485 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004486 "data_gen": {
4487 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4488 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004489 },
Won Jeona21b2e82023-08-10 10:33:01 +00004490 "dim": {
4491 "op": Op.DIM,
4492 "operands": (1, 0),
4493 "build_fcn": (
4494 build_dim,
4495 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004496 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004497 TosaArgGen.agAxis,
4498 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004499 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Won Jeona21b2e82023-08-10 10:33:01 +00004500 "error_if_validators": (
4501 TosaErrorValidator.evAxisLargerRank,
4502 TosaErrorValidator.evAxisSmallerZero,
4503 TosaErrorValidator.evWrongInputType,
4504 TosaErrorValidator.evWrongInputList,
4505 TosaErrorValidator.evWrongOutputList,
4506 TosaErrorValidator.evWrongRank,
4507 ),
4508 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004509 "reshape": {
4510 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004511 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004512 "build_fcn": (
4513 build_reshape,
4514 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004515 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004516 TosaArgGen.agReshape,
4517 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004518 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004519 "error_if_validators": (
4520 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4521 TosaErrorValidator.evWrongInputType,
4522 TosaErrorValidator.evWrongOutputType,
4523 TosaErrorValidator.evWrongInputList,
4524 TosaErrorValidator.evWrongOutputList,
4525 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004526 "data_gen": {
4527 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4528 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004529 },
4530 "reverse": {
4531 "op": Op.REVERSE,
4532 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004533 "build_fcn": (
4534 build_reverse,
4535 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004536 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004537 TosaArgGen.agAxis,
4538 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004539 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004540 "error_if_validators": (
4541 TosaErrorValidator.evAxisSmallerZero,
4542 TosaErrorValidator.evAxisLargerRank,
4543 TosaErrorValidator.evWrongInputType,
4544 TosaErrorValidator.evWrongOutputType,
4545 TosaErrorValidator.evWrongInputList,
4546 TosaErrorValidator.evWrongOutputList,
4547 ),
evacha0198477222024-01-26 12:25:32 +00004548 "data_gen": {
4549 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4550 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004551 },
4552 "slice": {
4553 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004554 "operands": (3, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004555 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004556 "build_fcn": (
4557 build_slice,
4558 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004559 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004560 TosaArgGen.agSlice,
4561 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004562 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004563 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004564 # TODO Turn off these error categories for now as the reference
4565 # model cannot allocate memory space for empty tensor. We probably
4566 # can report an accurate error messege at the right place during
4567 # exeuction.
4568 # TosaErrorValidator.evStartSmallerZero,
4569 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004570 TosaErrorValidator.evStartSizeOutsideBounds,
4571 TosaErrorValidator.evSizeOutputShapeMismatch,
4572 TosaErrorValidator.evInputSizeStartLengthMismatch,
4573 TosaErrorValidator.evWrongRank,
4574 TosaErrorValidator.evWrongInputType,
4575 TosaErrorValidator.evWrongOutputType,
4576 TosaErrorValidator.evWrongInputList,
4577 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004578 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004579 ),
evacha017f7d4252024-01-24 12:08:09 +00004580 "data_gen": {
4581 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4582 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004583 },
4584 "tile": {
4585 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004586 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004587 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004588 "build_fcn": (
4589 build_tile,
4590 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004591 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004592 TosaArgGen.agTile,
4593 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004594 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004595 "error_if_validators": (
4596 TosaErrorValidator.evWrongInputType,
4597 TosaErrorValidator.evWrongOutputType,
4598 TosaErrorValidator.evWrongInputList,
4599 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004600 TosaErrorValidator.evRankMismatch,
4601 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004602 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004603 "data_gen": {
4604 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4605 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004606 },
4607 "transpose": {
4608 "op": Op.TRANSPOSE,
4609 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004610 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004611 "build_fcn": (
4612 build_transpose,
4613 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004614 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004615 TosaArgGen.agTranspose,
4616 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004617 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004618 "error_if_validators": (
4619 TosaErrorValidator.evIndexOutsideBounds,
4620 TosaErrorValidator.evIndexUsedTwice,
4621 TosaErrorValidator.evWrongInputType,
4622 TosaErrorValidator.evWrongOutputType,
4623 TosaErrorValidator.evWrongInputList,
4624 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004625 TosaErrorValidator.evWrongRank,
4626 TosaErrorValidator.evRankMismatch,
4627 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004628 ),
evacha0198477222024-01-26 12:25:32 +00004629 "data_gen": {
4630 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4631 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004632 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004633 # Data nodes
4634 "const": {
4635 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004636 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004637 "build_fcn": (
4638 build_const,
4639 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004640 TosaTensorValuesGen.tvgLazyGenDefault,
4641 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004642 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004643 "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
evacha0198477222024-01-26 12:25:32 +00004644 "data_gen": {
4645 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4646 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004647 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004648 "identity": {
4649 "op": Op.IDENTITY,
4650 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004651 "build_fcn": (
4652 build_unary,
4653 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004654 TosaTensorValuesGen.tvgLazyGenDefault,
4655 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004656 ),
evacha011adff832024-03-06 17:33:44 +00004657 "types": TYPE_FIB + [DType.INT4, DType.INT48],
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004658 "data_gen": {
4659 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4660 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004661 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004662 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004663 "gather": {
4664 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004665 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004666 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004667 "build_fcn": (
4668 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004669 TosaTensorGen.tgGather,
4670 TosaTensorValuesGen.tvgGather,
4671 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004672 ),
James Ward24dbc422022-10-19 12:20:31 +01004673 "types": (
4674 DType.INT8,
4675 DType.INT16,
4676 DType.INT32,
4677 DType.FP16,
4678 DType.BF16,
4679 DType.FP32,
Won Jeon2c34b462024-02-06 18:37:00 +00004680 DType.FP8E4M3,
4681 DType.FP8E5M2,
James Ward24dbc422022-10-19 12:20:31 +01004682 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004683 "error_if_validators": (
4684 TosaErrorValidator.evWrongInputType,
4685 TosaErrorValidator.evWrongOutputType,
4686 TosaErrorValidator.evWrongInputList,
4687 TosaErrorValidator.evWrongOutputList,
4688 TosaErrorValidator.evWrongRank,
4689 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004690 "data_gen": {
4691 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4692 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004693 },
4694 "scatter": {
4695 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004696 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004697 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004698 "build_fcn": (
4699 build_scatter,
4700 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004701 TosaTensorValuesGen.tvgScatter,
4702 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004703 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004704 "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004705 "error_if_validators": (
4706 TosaErrorValidator.evWrongInputType,
4707 TosaErrorValidator.evWrongOutputType,
4708 TosaErrorValidator.evWrongInputList,
4709 TosaErrorValidator.evWrongOutputList,
4710 TosaErrorValidator.evWrongRank,
4711 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004712 "data_gen": {
4713 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4714 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004715 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004716 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004717 "resize": {
4718 "op": Op.RESIZE,
Tai Lyc5c2a7e2024-02-22 23:26:28 +00004719 "operands": (4, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004720 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004721 "build_fcn": (
4722 build_resize,
4723 TosaTensorGen.tgNHWC,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004724 TosaTensorValuesGen.tvgResize,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004725 TosaArgGen.agResize,
4726 ),
James Ward24dbc422022-10-19 12:20:31 +01004727 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004728 "invalid_test_validators": (
4729 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004730 ),
4731 "error_if_validators": (
4732 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004733 TosaErrorValidator.evScaleSmallerEqualZero,
4734 TosaErrorValidator.evScaleNLargerMax,
4735 TosaErrorValidator.evScaleDLargerMax,
4736 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004737 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004738 TosaErrorValidator.evBorderSmallerMin,
4739 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004740 TosaErrorValidator.evWrongInputType,
4741 TosaErrorValidator.evWrongOutputType,
4742 TosaErrorValidator.evWrongRank,
4743 TosaErrorValidator.evWrongInputList,
4744 TosaErrorValidator.evWrongOutputList,
4745 TosaErrorValidator.evBatchMismatch,
4746 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004747 TosaErrorValidator.evResizeOutputShapeMismatch,
4748 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004749 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004750 "data_gen": {
4751 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4752 },
4753 "compliance": {"relative": 0.006},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004754 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004755 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004756 "cast": {
4757 "op": Op.CAST,
4758 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004759 "build_fcn": (
4760 build_cast,
4761 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004762 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004763 TosaArgGen.agCast,
4764 ),
James Ward8b390432022-08-12 20:48:56 +01004765 "types": (
4766 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004767 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004768 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004769 DType.INT8,
4770 DType.INT16,
4771 DType.INT32,
4772 DType.BOOL,
Won Jeon2c34b462024-02-06 18:37:00 +00004773 DType.FP8E4M3,
4774 DType.FP8E5M2,
James Ward8b390432022-08-12 20:48:56 +01004775 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004776 "error_if_validators": (
4777 TosaErrorValidator.evWrongInputType,
4778 TosaErrorValidator.evWrongOutputType,
4779 TosaErrorValidator.evWrongInputList,
4780 TosaErrorValidator.evWrongOutputList,
4781 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004782 "data_gen": {
4783 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4784 },
4785 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004786 },
4787 "rescale": {
4788 "op": Op.RESCALE,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004789 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004790 "build_fcn": (
4791 build_rescale,
4792 TosaTensorGen.tgBasic,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004793 TosaTensorValuesGen.tvgRescale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004794 TosaArgGen.agRescale,
4795 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004796 "types": [
4797 DType.UINT8,
4798 DType.INT8,
4799 DType.INT16,
4800 DType.INT32,
4801 DType.INT48,
4802 DType.UINT16,
4803 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004804 "error_if_validators": (
4805 TosaErrorValidator.evInputZeroPointNotZero,
4806 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004807 TosaErrorValidator.evU16InputZeroPointNotValid,
4808 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004809 TosaErrorValidator.evScaleTrue,
4810 TosaErrorValidator.evScaleNotTrue,
4811 TosaErrorValidator.evWrongInputType,
4812 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004813 TosaErrorValidator.evWrongInputList,
4814 TosaErrorValidator.evWrongOutputList,
4815 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004816 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004817 # Custom
4818 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004819 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004820 # Two variants of cond_if, one that generates one of two constant tensors (no
4821 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4822 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004823 "cond_if_const": {
4824 "op": Op.COND_IF,
4825 "operands": (0, 2),
4826 "build_fcn": (
4827 build_cond_if_const,
4828 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004829 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004830 TosaArgGen.agCondIf,
4831 ),
4832 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004833 "error_if_validators": (
4834 TosaErrorValidator.evOutputListThenGraphMismatch,
4835 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004836 TosaErrorValidator.evCondIfCondNotMatchingBool,
4837 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004838 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004839 },
4840 "cond_if_binary": {
4841 "op": Op.COND_IF,
4842 "operands": (2, 0),
4843 "build_fcn": (
4844 build_cond_if_binary,
4845 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004846 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004847 TosaArgGen.agCondIf,
4848 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004849 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004850 "error_if_validators": (
4851 TosaErrorValidator.evInputListThenGraphMismatch,
4852 TosaErrorValidator.evInputListElseGraphMismatch,
4853 TosaErrorValidator.evOutputListThenGraphMismatch,
4854 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004855 TosaErrorValidator.evCondIfCondNotMatchingBool,
4856 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004857 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004858 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004859 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004860 "while_loop": {
4861 "op": Op.WHILE_LOOP,
4862 "operands": (0, 1),
4863 "build_fcn": (
4864 build_while_loop,
4865 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004866 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004867 TosaArgGen.agWhileLoop,
4868 ),
4869 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004870 "error_if_validators": (
4871 TosaErrorValidator.evInputListOutputListMismatch,
4872 TosaErrorValidator.evInputListCondGraphMismatch,
4873 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4874 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4875 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004876 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004877 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004878 },
Luke Hutton57287132023-02-06 14:54:18 +00004879 "fft2d": {
4880 "op": Op.FFT2D,
4881 "operands": (2, 0),
4882 "rank": (3, 3),
4883 "build_fcn": (
4884 build_fft2d,
4885 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004886 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004887 TosaArgGen.agFFT2d,
4888 ),
4889 "types": [DType.FP32],
4890 "error_if_validators": (
4891 TosaErrorValidator.evWrongInputType,
4892 TosaErrorValidator.evWrongOutputType,
4893 TosaErrorValidator.evWrongInputList,
4894 TosaErrorValidator.evWrongOutputList,
4895 TosaErrorValidator.evWrongRank,
4896 TosaErrorValidator.evBatchMismatch,
4897 TosaErrorValidator.evKernelNotPowerOfTwo,
4898 TosaErrorValidator.evFFTInputShapeMismatch,
4899 TosaErrorValidator.evFFTOutputShapeMismatch,
4900 ),
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004901 "data_gen": {
4902 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4903 },
Luke Hutton57287132023-02-06 14:54:18 +00004904 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004905 "rfft2d": {
4906 "op": Op.RFFT2D,
4907 "operands": (1, 0),
4908 "rank": (3, 3),
4909 "build_fcn": (
4910 build_rfft2d,
4911 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004912 TosaTensorValuesGen.tvgLazyGenDefault,
4913 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00004914 ),
4915 "types": [DType.FP32],
4916 "error_if_validators": (
4917 TosaErrorValidator.evWrongInputType,
4918 TosaErrorValidator.evWrongOutputType,
4919 TosaErrorValidator.evWrongInputList,
4920 TosaErrorValidator.evWrongOutputList,
4921 TosaErrorValidator.evWrongRank,
4922 TosaErrorValidator.evBatchMismatch,
4923 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004924 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004925 ),
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004926 "data_gen": {
4927 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4928 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004929 },
Won Jeon74342e52024-01-09 00:34:40 +00004930 # Shape
4931 "add_shape": {
4932 "op": Op.ADD_SHAPE,
4933 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004934 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004935 "build_fcn": (
4936 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004937 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004938 TosaTensorValuesGen.tvgAddSub,
4939 TosaArgGen.agNone,
4940 ),
4941 "types": [DType.SHAPE],
4942 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4943 },
4944 "sub_shape": {
4945 "op": Op.SUB_SHAPE,
4946 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004947 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004948 "build_fcn": (
4949 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004950 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004951 TosaTensorValuesGen.tvgAddSub,
4952 TosaArgGen.agNone,
4953 ),
4954 "types": [DType.SHAPE],
4955 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4956 },
4957 "mul_shape": {
4958 "op": Op.MUL_SHAPE,
4959 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004960 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004961 "build_fcn": (
4962 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004963 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004964 TosaTensorValuesGen.tvgMul,
4965 TosaArgGen.agNone,
4966 ),
4967 "types": [DType.SHAPE],
4968 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4969 },
4970 "div_shape": {
4971 "op": Op.DIV_SHAPE,
4972 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004973 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004974 "build_fcn": (
4975 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004976 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004977 TosaTensorValuesGen.tvgIntDiv,
4978 TosaArgGen.agNone,
4979 ),
4980 "types": [DType.SHAPE],
4981 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4982 },
4983 "concat_shape": {
4984 "op": Op.CONCAT_SHAPE,
4985 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004986 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004987 "build_fcn": (
4988 build_concat,
4989 TosaTensorGen.tgConcat,
4990 TosaTensorValuesGen.tvgConcat,
4991 TosaArgGen.agNone,
4992 ),
4993 "types": [DType.SHAPE],
4994 "error_if_validators": (),
4995 },
4996 "const_shape": {
4997 "op": Op.CONST_SHAPE,
4998 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004999 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005000 "build_fcn": (
5001 build_const,
5002 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00005003 TosaTensorValuesGen.tvgLazyGenDefault,
5004 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00005005 ),
5006 "types": [DType.SHAPE],
5007 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005008 }
5009
Kevin Cheng550ccc52021-03-03 11:21:43 -08005010
Eric Kunzee5e26762020-10-13 16:11:07 -07005011class OutputShaper:
5012 # Methods in this class compute the expected output shape and datatype
5013 # for common classes of operations
5014 def __init__(self):
5015 pass
5016
5017 # These methods return arguments that can be used for
5018 # creating a new output tensor
5019 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005020 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
5021 if error_name != ErrorIf.RankMismatch:
5022 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005023 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005024
5025 shape = []
5026 for i in range(len(a.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005027 if a.shape[i] == 1 and error_name is None:
Eric Kunzee5e26762020-10-13 16:11:07 -07005028 shape.append(b.shape[i])
5029 else:
5030 shape.append(a.shape[i])
5031
Jerry Ge135c9552023-05-23 20:59:32 +00005032 fuzz_idx = rng.integers(0, len(a.shape))
5033 if error_name == ErrorIf.DimensionMismatch:
5034 shape[fuzz_idx] += 1
5035
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005036 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005037 all_dtypes = [
5038 DType.INT8,
5039 DType.INT16,
5040 DType.INT32,
5041 DType.INT48,
James Ward24dbc422022-10-19 12:20:31 +01005042 DType.FP16,
5043 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005044 DType.FP32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005045 ]
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005046 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5047 outputDType = rng.choice(wrong_dtypes)
5048 else:
5049 outputDType = a.dtype
5050
5051 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005052
5053 @staticmethod
5054 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005055 assert len(a.shape) == len(b.shape)
5056 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005057
5058 shape = []
5059 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005060 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07005061 shape.append(a.shape[i])
5062
Kevin Cheng550ccc52021-03-03 11:21:43 -08005063 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005064
5065 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005066 def unaryOp(ser, rng, a, error_name=None):
5067 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005068 all_dtypes = [
5069 DType.INT8,
5070 DType.INT16,
5071 DType.INT32,
5072 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005073 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005074 DType.FP16,
5075 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005076 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005077 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5078 outputDType = rng.choice(wrong_dtypes)
5079 else:
5080 outputDType = a.dtype
5081
5082 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005083
5084 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005085 def selectOp(ser, rng, cond, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005086 if error_name != ErrorIf.RankMismatch:
5087 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005088 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005089
5090 shape = []
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005091 for i in range(len(cond.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005092 if cond.shape[i] == 1 and error_name is None:
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005093 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
5094 else:
5095 shape.append(cond.shape[i])
Eric Kunzee5e26762020-10-13 16:11:07 -07005096
Jerry Ge135c9552023-05-23 20:59:32 +00005097 fuzz_idx = rng.integers(0, len(a.shape))
5098 if error_name == ErrorIf.DimensionMismatch:
5099 shape[fuzz_idx] += 1
5100
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005101 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005102 all_dtypes = [
5103 DType.INT8,
5104 DType.INT16,
5105 DType.INT32,
5106 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005107 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005108 DType.FP16,
5109 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005110 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005111 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5112 outputDType = rng.choice(wrong_dtypes)
5113 else:
5114 outputDType = a.dtype
5115
5116 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005117
5118 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005119 def binaryComparisonOp(ser, rng, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005120 if error_name != ErrorIf.RankMismatch:
5121 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005122 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005123
5124 # Do broadcast
5125 shape = []
5126 for i in range(len(a.shape)):
Eric Kunzea1d49852022-01-04 10:07:29 -08005127 if a.shape[i] == 1 and len(b.shape) > i:
Eric Kunzee5e26762020-10-13 16:11:07 -07005128 shape.append(b.shape[i])
5129 else:
5130 shape.append(a.shape[i])
5131
Jerry Ge135c9552023-05-23 20:59:32 +00005132 fuzz_idx = rng.integers(0, len(a.shape))
5133 if error_name == ErrorIf.DimensionMismatch:
5134 shape[fuzz_idx] += 1
5135
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005136 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005137 wrong_dtypes = [
5138 DType.INT8,
5139 DType.INT16,
5140 DType.INT32,
5141 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005142 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005143 DType.FP16,
5144 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005145 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005146 outputDType = rng.choice(wrong_dtypes)
5147 else:
5148 outputDType = DType.BOOL
5149
5150 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005151
5152 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01005153 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005154 shape = a.shape.copy()
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005155 if error_name not in [
5156 ErrorIf.AxisSmallerZero,
5157 ErrorIf.AxisLargerRank,
5158 ErrorIf.ShapeOfAxisNotOne,
5159 ]:
Matthew Haddond6ce7252021-09-29 15:35:44 +01005160 shape[axis] = 1
5161 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
5162 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07005163
Matthew Haddond6ce7252021-09-29 15:35:44 +01005164 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005165 all_dtypes = [
5166 DType.INT8,
5167 DType.INT16,
5168 DType.INT32,
5169 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005170 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005171 DType.FP16,
5172 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005173 ]
Matthew Haddond6ce7252021-09-29 15:35:44 +01005174 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5175 outputDType = rng.choice(wrong_dtypes)
5176 else:
5177 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005178
Matthew Haddond6ce7252021-09-29 15:35:44 +01005179 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005180
5181 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005182 def argmaxOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005183 shape = a.shape.copy()
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005184
5185 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
5186 del shape[axis]
5187
5188 if error_name == ErrorIf.ArgmaxOutputRankMismatch:
5189 remove = rng.choice([True, False])
5190 if remove and len(shape) > 1:
5191 del shape[0]
5192 else:
5193 shape.append(1)
5194 elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
5195 for i in range(len(shape)):
5196 shape[i] = shape[i] + rng.integers(1, 10)
5197
5198 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005199 all_dtypes = [
5200 DType.INT8,
5201 DType.INT16,
5202 DType.INT32,
5203 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005204 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005205 DType.FP16,
5206 DType.BF16,
Won Jeon2c34b462024-02-06 18:37:00 +00005207 DType.FP8E4M3,
5208 DType.FP8E5M2,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005209 ]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005210 wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
5211 outputDType = rng.choice(wrong_dtypes)
5212 else:
5213 outputDType = DType.INT32
5214
5215 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005216
5217 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01005218 def conv2dOp(
5219 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
5220 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07005221
5222 # IFM: NHWC
5223 # Filter: OHWI
5224 # OFM: NHWC
5225
Kevin Cheng550ccc52021-03-03 11:21:43 -08005226 h = (
5227 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005228 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08005229 + padding[0]
5230 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005231 - (filter.shape[1] - 1) * dilations[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08005232 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005233
Kevin Cheng550ccc52021-03-03 11:21:43 -08005234 w = (
5235 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005236 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08005237 + padding[2]
5238 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005239 - (filter.shape[2] - 1) * dilations[1]
Kevin Cheng550ccc52021-03-03 11:21:43 -08005240 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005241
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005242 if error_name == ErrorIf.ConvOutputShapeMismatch:
5243 choices = [1, 2, 3]
5244 change = rng.choice(choices)
5245 # increment in multiples of stride to not hit non-integer error case
5246 if change in [1, 3]:
5247 h = h + (rng.choice(choices) * strides[0])
5248 if change in [2, 3]:
5249 w = w + (rng.choice(choices) * strides[1])
Les Bell0e027d42021-11-09 14:42:14 +00005250
Eric Kunzee5e26762020-10-13 16:11:07 -07005251 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
5252
James Ward8b390432022-08-12 20:48:56 +01005253 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00005254 # Pick some potentially correct output dtype if input type is incorrect
5255 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07005256 else:
James Ward8b390432022-08-12 20:48:56 +01005257 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00005258
5259 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01005260 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005261 excludes = [DType.FP16, DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00005262 if ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
5263 excludes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01005264 else:
5265 excludes = [out_dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +01005266 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00005267 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07005268
Kevin Cheng550ccc52021-03-03 11:21:43 -08005269 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005270
5271 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01005272 def conv3dOp(
5273 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
5274 ):
Kevin Cheng1533b852021-09-01 12:51:58 -07005275
5276 # IFM: NDHWC
5277 # Filter: ODHWI
5278 # OFM: NDHWC
5279
5280 d = (
5281 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005282 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07005283 + padding[0]
5284 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005285 - (filter.shape[1] - 1) * dilations[0]
Kevin Cheng1533b852021-09-01 12:51:58 -07005286 ) // strides[0] + 1
5287
5288 h = (
5289 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005290 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07005291 + padding[2]
5292 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005293 - (filter.shape[2] - 1) * dilations[1]
Kevin Cheng1533b852021-09-01 12:51:58 -07005294 ) // strides[1] + 1
5295
5296 w = (
5297 ifm.shape[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005298 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07005299 + padding[4]
5300 + padding[5]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005301 - (filter.shape[3] - 1) * dilations[2]
Kevin Cheng1533b852021-09-01 12:51:58 -07005302 ) // strides[2] + 1
5303
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005304 if error_name == ErrorIf.ConvOutputShapeMismatch:
5305 choices = [1, 2, 3, 4]
5306 change = rng.choice(choices)
5307 # increment in multiples of stride to not hit non-integer error case
5308 if change in [1, 4]:
5309 d = d + (rng.choice(choices) * strides[0])
5310 if change in [2, 4]:
5311 h = h + (rng.choice(choices) * strides[1])
5312 if change in [3, 4]:
5313 w = w + (rng.choice(choices) * strides[2])
Les Bell0e027d42021-11-09 14:42:14 +00005314
Kevin Cheng1533b852021-09-01 12:51:58 -07005315 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
5316
James Ward8b390432022-08-12 20:48:56 +01005317 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00005318 # Pick some potentially correct output dtype if input type is incorrect
5319 out_dtype = DType.INT32
Kevin Cheng1533b852021-09-01 12:51:58 -07005320 else:
James Ward8b390432022-08-12 20:48:56 +01005321 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00005322
5323 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01005324 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005325 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01005326 else:
5327 excludes = [out_dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +01005328 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00005329 out_dtype = rng.choice(wrong_dtypes)
Kevin Cheng1533b852021-09-01 12:51:58 -07005330
5331 return ser.addOutput(ofm_shape, out_dtype)
5332
5333 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005334 def depthwiseConv2dOp(
James Ward8b390432022-08-12 20:48:56 +01005335 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005336 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07005337 # IFM: NHWC
5338 # Filter: HWCM
5339 # OFM: NHW C*M
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005340
Kevin Cheng550ccc52021-03-03 11:21:43 -08005341 h = (
5342 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005343 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08005344 + padding[0]
5345 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005346 - (filter.shape[0] - 1) * dilations[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08005347 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005348
Kevin Cheng550ccc52021-03-03 11:21:43 -08005349 w = (
5350 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005351 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08005352 + padding[2]
5353 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005354 - (filter.shape[1] - 1) * dilations[1]
Kevin Cheng550ccc52021-03-03 11:21:43 -08005355 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005356
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005357 if error_name == ErrorIf.ConvOutputShapeMismatch:
5358 choices = [1, 2, 3]
5359 change = rng.choice(choices)
5360 # increment in multiples of stride to not hit non-integer error case
5361 if change in [1, 3]:
5362 h = h + (rng.choice(choices) * strides[0])
5363 if change in [2, 3]:
5364 w = w + (rng.choice(choices) * strides[1])
Les Bell0e027d42021-11-09 14:42:14 +00005365
Eric Kunzee5e26762020-10-13 16:11:07 -07005366 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
5367
James Ward8b390432022-08-12 20:48:56 +01005368 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00005369 # Pick some potentially correct output dtype if input type is incorrect
5370 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07005371 else:
James Ward8b390432022-08-12 20:48:56 +01005372 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00005373
5374 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01005375 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005376 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01005377 else:
5378 excludes = [out_dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +01005379 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00005380 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07005381
Kevin Cheng550ccc52021-03-03 11:21:43 -08005382 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005383
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor for a 2D pooling operator.

    ``ifm`` and the output are NHWC; only H and W change. ``pad[0:2]``
    applies to H and ``pad[2:4]`` to W; ``kernel``/``stride`` are [y, x].

    ERROR_IF handling:
      * PoolingOutputShapeMismatch - OH and/or OW grow by 1-3 whole strides.
      * WrongOutputType - any listed dtype other than the input dtype.
    """
    if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
        # Invalid stride/pad: the test is invalid anyway, so use a 1x1 output
        out_h, out_w = 1, 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        deltas = [1, 2, 3]
        which = rng.choice(deltas)
        # Whole multiples of the stride avoid tripping the non-integer ERROR_IF
        if which in [1, 3]:
            out_h += rng.choice(deltas) * stride[0]
        if which in [2, 3]:
            out_w += rng.choice(deltas) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    candidates = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    )
    wrong_dtypes = list(set(candidates) - {ifm.dtype})
    return ser.addOutput(ofm_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005423
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Create the output tensor for FULLY_CONNECTED.

    ``input`` is (N, IC) and ``filter`` is (OC, IC); the output is (N, OC).
    The output dtype is ``accum_dtype`` as-is: it is validated (and
    invalidated for ERROR_IF tests) during argument generation.
    ``rng`` and ``error_name`` are unused here.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005436
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the output tensor for MATMUL.

    Shapes: ``a`` is (N, H, C), ``b`` is (N, C, W) and the output is (N, H, W).

    Output dtype:
      * ErrorIf.WrongOutputType - a random dtype from a per-input-dtype list
        of incorrect types (falls back to every usable dtype except
        ``accum_dtype`` for input dtypes without a dedicated list).
      * ErrorIf.WrongInputType - INT32, a potentially correct output, since
        the input type is already wrong.
      * otherwise ``accum_dtype`` (validated in arg_gen).
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            # INT32 is intentionally absent from this list
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype == DType.INT16:
            # INT48 is intentionally absent from this list
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype in (DType.FP8E4M3, DType.FP8E5M2):
            # FP16 is intentionally absent from this list
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            # FP32/FP16/BF16 are all absent from this list
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        else:
            # Previously any other input dtype left incorrect_types unbound
            # and crashed with UnboundLocalError. Fall back to every usable
            # dtype except the expected accumulator type.
            incorrect_types = tuple(gtu.usableDTypes(excludes=[accum_dtype]))
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005502
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Create the output tensor for CONCAT of ``inputs`` along ``axis``.

    The output shape starts as a copy of the first input's shape. Unless the
    ranks mismatch or the axis is invalid (so no sum can be formed), the
    other inputs' extents along ``axis`` are added to it.

    ERROR_IF handling:
      * ConcatShapeSumMismatch - the axis extent is inflated by 5-9.
      * WrongOutputType - any listed dtype other than the first input's.
    """
    first = inputs[0]
    others = inputs[1:]

    output_shape = first.shape.copy()
    unsummable = (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if error_name not in unsummable:
        for other in others:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    out_dtype = first.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - {first.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005538
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor for PAD.

    Each output dim is ``padding[i][0] + padding[i][1] + a.shape[i]``.

    ERROR_IF handling:
      * PadOutputShapeMismatch - one random dim grows by 1 or 2.
      * RankMismatch - the shape is replaced by gtu.get_rank_mismatch_shape.
      * WrongOutputType - any listed dtype other than the input dtype.
    """
    output_shape = a.shape.copy()
    for idx, extent in enumerate(a.shape):
        pad_before = padding[idx][0]
        pad_after = padding[idx][1]
        output_shape[idx] = pad_before + pad_after + extent

    if error_name == ErrorIf.PadOutputShapeMismatch:
        victim = rng.choice(range(len(output_shape)))
        output_shape[victim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005569
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Create the single-element output tensor for DIM.

    The output shape is always [1] with dtype SHAPE, except for
    ErrorIf.WrongOutputType where a random non-SHAPE dtype is used.
    ``a`` and ``axis`` are not read here.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([1], DType.SHAPE)

    non_shape_dtypes = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    return ser.addOutput([1], rng.choice(list(set(non_shape_dtypes))))
5590
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor for RESHAPE with the requested ``shape``.

    ERROR_IF handling:
      * TensorSizeInputOutputMismatch - every dim is grown by 1-9.
      * WrongOutputType - any listed dtype other than the input dtype.
    """
    new_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for pos, extent in enumerate(new_shape):
            new_shape[pos] = extent + rng.integers(1, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(new_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005615
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Create the output tensor for SLICE; its shape is ``size``.

    ERROR_IF handling:
      * WrongOutputType - any listed dtype other than the input dtype.
      * SizeOutputShapeMismatch - every dim is nudged (dims <= 2 only grow).
      * InputSizeStartLengthMismatch - the input shape is used instead.
      * RankMismatch - the shape is replaced by gtu.get_rank_mismatch_shape.
    ``start`` is not read here.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {input.dtype}))
    else:
        out_dtype = input.dtype

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, extent in enumerate(output_shape):
            # Small dims are only grown so they never become non-positive
            tweaks = [1, 2] if extent <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = extent + rng.choice(tweaks)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005649
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor for TILE: dim i is ``a.shape[i] * multiples[i]``.

    ERROR_IF handling:
      * RankMismatch - the shape is replaced by gtu.get_rank_mismatch_shape.
      * WrongOutputType - any listed dtype other than the input dtype.
    """
    tiled_shape = a.shape.copy()
    assert len(multiples) == len(tiled_shape)

    for axis_idx, extent in enumerate(a.shape):
        tiled_shape[axis_idx] = extent * multiples[axis_idx]

    if error_name == ErrorIf.RankMismatch:
        tiled_shape = gtu.get_rank_mismatch_shape(rng, tiled_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(tiled_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005678
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor for TRANSPOSE.

    Output dim i is ``a.shape[perms[i]]``. When ``perms`` is deliberately
    invalid (IndexOutsideBounds / IndexUsedTwice), the input shape is kept.

    ERROR_IF handling:
      * TensorSizeInputOutputMismatch - every dim grows by 1-9.
      * RankMismatch - the shape is replaced by gtu.get_rank_mismatch_shape.
      * WrongOutputType - any listed dtype other than the input dtype.
    """
    output_shape = a.shape.copy()
    assert len(perms) == len(output_shape)

    invalid_perm_errors = (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice)
    if error_name not in invalid_perm_errors:
        for dst in range(len(output_shape)):
            output_shape[dst] = a.shape[perms[dst]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dst in range(len(output_shape)):
            output_shape[dst] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005711
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor for GATHER.

    ``values`` is rank 3 (unless testing WrongRank) and ``indices`` rank 2,
    sharing dim 0. The output is
    [values.shape[0], indices.shape[1], values.shape[2]].
    For ErrorIf.WrongOutputType a listed dtype other than ``values.dtype``
    is chosen.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    batch = values.shape[0]
    width = indices.shape[1]
    channels = values.shape[2]
    output_shape = [batch, width, channels]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, values.dtype)

    candidates = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    out_dtype = rng.choice(list(set(candidates) - {values.dtype}))
    return ser.addOutput(output_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005737
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor for SCATTER.

    The output has ``values_in``'s shape (the same shape object is passed
    through, not a copy). The asserts check rank 3/2/3 for
    values_in/indices/input and that the N, W and C extents agree.
    For ErrorIf.WrongOutputType a listed dtype other than
    ``values_in.dtype`` is chosen.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        )
        out_dtype = rng.choice(list(set(candidates) - {values_in.dtype}))
    else:
        out_dtype = values_in.dtype

    return ser.addOutput(values_in.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005768
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor for TABLE.

    The output keeps the input shape. INT16 input gives INT32 output; any
    other input dtype gives INT8. For ErrorIf.WrongOutputType, a random
    dtype from a fixed candidate list (minus the expected one) is used.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    out_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice([dt for dt in candidates if dt != out_dtype])

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005788
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor for RESIZE.

    ``scale`` is read as [y_n, y_d, x_n, x_d]; ``offset``/``border`` as
    [y, x]. OH/OW are computed from input dims 1 and 2 and the output is
    [N, OH, OW, C] with ``output_dtype``. ``mode`` and ``input_dtype`` are
    not read here.

    ERROR_IF handling:
      * ScaleSmallerEqualZero - scale terms are clamped to >= 1 for the
        shape arithmetic only.
      * any error - OH/OW are clamped to >= 1 and, unless MaxDimExceeded,
        to < gtu.MAX_RESIZE_DIMENSION.
      * ResizeOutputShapeMismatch - OH and/or OW move by one denominator.
      * WrongRank / BatchMismatch / ChannelMismatch - output dims distorted.
    """
    scale_y_n = scale[0]
    scale_y_d = scale[1]
    scale_x_n = scale[2]
    scale_x_d = scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
            max(term, 1) for term in (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        )

    in_h = input.shape[1]
    in_w = input.shape[2]
    oh = ((in_h - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
    ow = ((in_w - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

    if error_name is not None:
        # Keep the output tensor valid when scale/offset/border were altered
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, limit), min(ow, limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        which = rng.choice([1, 2, 3])
        # Step by whole denominators so the non-integer ERROR_IF is not hit
        if which in [1, 3]:
            if oh + scale_y_d >= gtu.MAX_RESIZE_DIMENSION:
                oh -= scale_y_d
                assert oh > 0  # Should have been caught in agResize
            else:
                oh += scale_y_d
        if which in [2, 3]:
            if ow + scale_x_d >= gtu.MAX_RESIZE_DIMENSION:
                ow -= scale_x_d
                assert ow > 0  # Should have been caught in agResize
            else:
                ow += scale_x_d

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # input.shape[3] is not read on this path; dim 0 is reused instead
        return serializer.addOutput([batch, oh, ow, batch], output_dtype)

    if error_name == ErrorIf.BatchMismatch:
        batch = batch + rng.integers(1, 10)
        channels = input.shape[3]
    elif error_name == ErrorIf.ChannelMismatch:
        channels = input.shape[3] + rng.integers(1, 10)
    else:
        channels = input.shape[3]

    return serializer.addOutput([batch, oh, ow, channels], output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005867
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create an output tensor shaped like ``val`` with dtype ``out_dtype``.

    ``rng`` and ``error_name`` are accepted for a uniform signature but are
    not used.
    """
    shape = val.shape
    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005871
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Create the output tensor for TRANSPOSE_CONV2D with ``output_shape``.

    NOTE(review): for ErrorIf.ConvOutputShapeMismatch the caller's
    ``output_shape`` list is modified in place (dims 1 and/or 2 grow by
    1-3); callers may depend on seeing that change, so no copy is made.

    Output dtype: INT32 for WrongInputType, otherwise ``accum_dtype``; for
    WrongOutputType a usable dtype outside the valid set is chosen.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        bumps = [1, 2, 3]
        which = rng.choice(bumps)
        if which in [1, 3]:
            output_shape[1] += rng.choice(bumps)
        if which in [2, 3]:
            output_shape[2] += rng.choice(bumps)

    # With a bad input type any potentially correct output dtype will do
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs may legally produce FP16 or FP32, so exclude both
        excluded = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005897
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Create the two output tensors for FFT2D.

    Both outputs (presumably the real and imaginary parts) share one shape
    and dtype, copied from ``ifm1``. The inputs must share a dtype and,
    unless testing FFTInputShapeMismatch, a shape; rank 3 is expected
    unless testing WrongRank.

    ERROR_IF handling:
      * WrongOutputType - a usable dtype other than FP32.
      * BatchMismatch - dim 0 grows by 1-9.
      * FFTOutputShapeMismatch - dim 1 or 2 grows by 1-9.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    in_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = in_shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis = rng.choice([1, 2])
        out_shape[axis] += rng.integers(1, 10)

    first = serializer.addOutput(out_shape, out_dtype)
    second = serializer.addOutput(out_shape, out_dtype)
    return [first, second]
5928
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Create the two output tensors for RFFT2D.

    Both outputs (presumably the real and imaginary parts) share one shape:
    the input shape with its last dim W replaced by W // 2 + 1. Rank 3 is
    expected unless testing WrongRank.

    ERROR_IF handling:
      * WrongOutputType - a usable dtype other than FP32.
      * BatchMismatch - dim 0 grows by 1-9.
      * FFTOutputShapeMismatch - dim 1 or 2 grows by 1-9.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = list(in_shape[:-1])
    out_shape.append(in_shape[-1] // 2 + 1)
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis = rng.choice([1, 2])
        out_shape[axis] += rng.integers(1, 10)

    first = serializer.addOutput(out_shape, out_dtype)
    second = serializer.addOutput(out_shape, out_dtype)
    return [first, second]
Won Jeon74342e52024-01-09 00:34:40 +00005953
@staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Create the SHAPE-typed output tensor for ADD_SHAPE.

    The output shape is a copy of ``a.shape``. Both inputs must share a
    dtype and, unless testing RankMismatch, a rank.

    ERROR_IF handling:
      * DimensionMismatch - one randomly chosen entry grows by 1.
      * WrongOutputType - a dtype from gtu.get_wrong_output_type(a.dtype).
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = list(a.shape)

    # Note: drawn on every call, so RNG state advances even when unused
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(gtu.get_wrong_output_type(a.dtype))
    else:
        out_dtype = DType.SHAPE
    return ser.addOutput(shape, out_dtype)