blob: 3173906db4b62fd111048c55efc56f7d4d7b0c7f [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnsonaf090182024-02-13 18:25:39 +00004import logging
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005import os
Tai Ly60dc48c2024-03-08 22:19:41 +00006import struct
Matthew Haddon630c17c2021-10-14 15:05:41 +01007from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01008from datetime import datetime
9from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -070010
Jeremy Johnson1271c442023-09-05 11:39:26 +010011import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000012import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000013import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010014from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010015from generator.tosa_arg_gen import TosaArgGen
16from generator.tosa_arg_gen import TosaQuantGen
17from generator.tosa_arg_gen import TosaTensorGen
18from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000019from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010020from generator.tosa_error_if import TosaErrorIfArgGen
21from generator.tosa_error_if import TosaErrorValidator
22from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010023from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000024from tosa.DType import DType
25from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010026
# Banner prepended to every generated C++ meta data file; the copyright year
# is taken from the date the generator runs.
TOSA_AUTOGENERATED_HEADER = "\n".join(
    (
        f"// Copyright (c) {datetime.today().year}, ARM Limited",
        "// SPDX-License-Identifier: Apache-2.0",
        "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests",
        "",
    )
)

# Give the root logger its default handler, then create this tool's logger.
logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")
34
Matthew Haddonb724efc2021-08-25 16:40:29 +010035
Eric Kunzee5e26762020-10-13 16:11:07 -070036class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010037 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000038 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010039 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010040 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010041 TOSA_8K_LEVEL_MAX_KERNEL = 8192
42 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010043
Jeremy Johnson1271c442023-09-05 11:39:26 +010044 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000045 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010046 TOSA_MI_DOT_PRODUCT_MIN = 1000
47
def __init__(self, args):
    """Initialise the generator from parsed command line arguments.

    Reads ``args.output_dir``, ``args.random_seed``,
    ``args.generate_lib_path`` and ``args.tensor_fp_value_range``. The whole
    ``args`` object is kept in ``self.args`` for the other methods.

    ``self.random_float_range`` maps each floating point DType to a sorted
    ``(low, high)`` tuple built from ``args.tensor_fp_value_range``:
    "max"/"-max" become +/- ``TVG_FLOAT_HIGH_VALUE[dtype]`` and other
    non-zero values are trimmed so they do not exceed that magnitude.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set, makeShape returns this shape instead of a random one
    self.targetted_shape = None
    # Used by serialize() to validate every desc.json before writing it
    self.descSchemaValidator = TestDescSchemaValidator()
    # The data generator library is sometimes needed for compliance set up
    # even when the data itself is generated later (lazy data generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def clampToDType(value, fpMax):
        # Map the symbolic limits, and trim numeric ones to the type's range
        if value == "max":
            return fpMax
        if value == "-max":
            return -fpMax
        if value < 0:
            return max(value, -fpMax)
        if value > 0:
            return min(value, fpMax)
        return value

    def fpRangeFor(fpMax):
        return tuple(
            sorted(clampToDType(v, fpMax) for v in args.tensor_fp_value_range)
        )

    floatTypes = (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2)
    self.random_float_range = {
        dtype: fpRangeFor(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype])
        for dtype in floatTypes
    }
89
def createSerializer(self, opName, testPath):
    """Create the test output directory and a fresh serializer for it.

    Sets ``self.testPath`` to ``opName/testPath`` and creates
    ``basePath/opName/testPath`` if needed. The constant mode is chosen as
    follows:
      * lazy data generation -> constants become input files (INPUTS);
      * otherwise ``--dump-consts`` -> embed and also dump (EMBED_DUMP);
      * otherwise constants are embedded in the flatbuffer (EMBED).
    """
    self.testPath = os.path.join(opName, testPath)

    outDir = os.path.join(self.basePath, self.testPath)
    os.makedirs(outDir, exist_ok=True)

    if self.args.lazy_data_gen:
        mode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        mode = ts.ConstMode.EMBED_DUMP
    else:
        mode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(outDir, mode)
Eric Kunzee5e26762020-10-13 16:11:07 -0700103
def getSerializer(self):
    """Return the serializer made by the last createSerializer call (or None)."""
    current = self.ser
    return current
106
def serialize(self, testName, metaData=None):
    """Write the current graph and its descriptor into the test directory.

    Files written under ``basePath/testPath``:
      * ``<testName>.tosa``: the serialized flatbuffer;
      * ``<testName>_meta_data_gen.cpp``: only when lazy data generation is
        enabled and ``metaData`` has a "data_gen" entry;
      * ``<testName>_meta_compliance.cpp``: only when ``metaData`` has a
        "compliance" entry;
      * ``desc.json``: the serializer's JSON descriptor, with ``metaData``
        stored under "meta" when it is given.

    The descriptor is schema-validated after the .tosa file is written but
    before any of the other files.
    """
    testDir = Path(self.basePath) / self.testPath
    tosaName = f"{testName}.tosa"

    # Binary TOSA flatbuffer for the graph
    with (testDir / tosaName).open("wb") as fbFile:
        fbFile.write(self.ser.serialize())

    # JSON descriptor from the serializer, optionally extended with meta data
    desc = json.loads(self.ser.writeJson(tosaName))
    if metaData:
        desc["meta"] = metaData

    # Validation failure stops here, before the remaining files are written
    self.descSchemaValidator.validate_config(desc)

    def writeCppConfig(fileSuffix, varPrefix, banner, config):
        # Embed a JSON config in C++ source as a raw string literal
        with (testDir / f"{testName}{fileSuffix}").open("w") as cppFile:
            cppFile.write(TOSA_AUTOGENERATED_HEADER)
            cppFile.write(banner)
            cppFile.write(f'const char* {varPrefix}_{testDir.stem} = R"(')
            json.dump(config, cppFile)
            cppFile.write(')";\n\n')

    if metaData:
        # Data generation config is only emitted for lazy data generation
        if "data_gen" in metaData and self.args.lazy_data_gen:
            writeCppConfig(
                "_meta_data_gen.cpp",
                "json_tdg_config",
                "// Test meta data for data generation setup\n\n",
                metaData["data_gen"],
            )
        # Compliance (verification) config
        if "compliance" in metaData:
            writeCppConfig(
                "_meta_compliance.cpp",
                "json_tvf_config",
                "// Test meta data for compliance validation\n\n",
                metaData["compliance"],
            )

    with (testDir / "desc.json").open("w") as descFile:
        json.dump(desc, descFile, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700150
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a new numpy Generator.

    Uses ``seed`` when given (including 0); otherwise uses
    ``self.random_seed + 1``.
    """
    chosenSeed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(chosenSeed)
155
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the ``(low, high)`` bounds used to generate values of ``dtype``.

    For integer-like types ``high`` is exclusive (``low <= v < high``).
    With ``high_inclusive=True`` it is reduced by one
    (``low <= v <= high``). Floating point types return the precomputed
    ``self.random_float_range[dtype]`` entry unchanged; ``high_inclusive``
    is not applied to them. SHAPE uses the first two entries of
    ``args.tensor_shape_range``.

    Raises:
        Exception: if ``dtype`` is not a recognised type.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
        return self.random_float_range[dtype]

    fixedBounds = (
        (DType.BOOL, (0, 2)),
        (DType.UINT8, (0, 256)),
        (DType.UINT16, (0, 65536)),
        # TOSA specific INT4 weight range from -7 to 7
        (DType.INT4, (-7, 8)),
        (DType.INT8, (-128, 128)),
        (DType.INT16, (-32768, 32768)),
        (DType.INT32, (-(1 << 31), 1 << 31)),
        (DType.INT48, (-(1 << 47), 1 << 47)),
    )
    bounds = next((b for t, b in fixedBounds if dtype == t), None)
    if bounds is None:
        if dtype != DType.SHAPE:
            raise Exception("Unknown dtype: {}".format(dtype))
        bounds = tuple(self.args.tensor_shape_range[0:2])

    if high_inclusive:
        return (bounds[0], bounds[1] - 1)
    return bounds
190
def getRandTensor(self, shape, dtype, data_range=None):
    """Return a random numpy array of ``shape`` holding ``dtype`` values.

    Values are drawn from ``data_range`` (a ``(low, high)`` pair) if given,
    otherwise from ``self.getDTypeRange(dtype)``. Array types:
      * BOOL: ``np.bool_`` of random False/True;
      * INT4/INT8: int8; UINT8: uint8; INT16: int16; UINT16: uint16;
      * INT48/SHAPE: int64; any other integer type: int32;
      * FP16: float16; FP32: float32;
      * BF16/FP8E4M3/FP8E5M2: float32 holding values passed through the
        matching ``gtu.vect_f32_to_*`` conversion.
    """
    low, high = self.getDTypeRange(dtype) if data_range is None else data_range

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))

    if dtype in (DType.FP16, DType.BF16, DType.FP32, DType.FP8E4M3, DType.FP8E5M2):
        samples = self.rng.uniform(low=low, high=high, size=shape)
        if dtype == DType.FP16:
            return np.float16(samples)
        samples32 = np.float32(samples)
        # Reduced-precision formats are kept in float32 storage after
        # conversion by the gtu helpers
        if dtype == DType.BF16:
            return np.float32(gtu.vect_f32_to_bf16(samples32))
        if dtype == DType.FP8E4M3:
            return np.float32(gtu.vect_f32_to_fp8e4m3(samples32))
        if dtype == DType.FP8E5M2:
            return np.float32(gtu.vect_f32_to_fp8e5m2(samples32))
        return samples32

    # Integer types: choose the numpy container, defaulting to int32
    container = np.int32
    for intType, npType in (
        (DType.INT4, np.int8),
        (DType.INT8, np.int8),
        (DType.UINT8, np.uint8),
        (DType.INT16, np.int16),
        (DType.UINT16, np.uint16),
        (DType.INT48, np.int64),
        (DType.SHAPE, np.int64),
    ):
        if dtype == intType:
            container = npType
            break
    return container(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700236
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder per ``(shape, dtype)`` pair to the serializer.

    With lazy data generation no data is attached (``None``); otherwise each
    placeholder gets ``self.getRandTensor(shape, dtype)``. Returns the
    ``self.ser.addPlaceholder`` results in input order.
    """
    assert len(shape_list) == len(dtype_list)

    placeholders = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        placeholders.append(self.ser.addPlaceholder(shape, dtype, data))
    return placeholders
249
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant tensor per ``(shape, dtype)`` pair to the serializer.

    With lazy data generation no data is attached (``None``); otherwise each
    constant gets ``self.getRandTensor(shape, dtype)``. Returns the
    ``self.ser.addConst`` results in input order.
    """
    assert len(shape_list) == len(dtype_list)

    consts = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        consts.append(self.ser.addConst(shape, dtype, data))
    return consts
262
def makeShape(self, rank):
    """Return a shape as an ``np.int32`` array.

    If ``self.targetted_shape`` is set (truthy) it is returned as int32.
    Otherwise ``rank`` dimensions are drawn from
    ``[tensor_shape_range[0], tensor_shape_range[1])``.
    """
    if not self.targetted_shape:
        lowDim = self.args.tensor_shape_range[0]
        highDim = self.args.tensor_shape_range[1]
        return np.int32(self.rng.integers(low=lowDim, high=highDim, size=rank))
    return np.int32(self.targetted_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700273
def setTargetShape(self, shape):
    """Force makeShape to return ``shape`` (a falsy value restores random shapes)."""
    setattr(self, "targetted_shape", shape)
276
def randInt(self, low=0, high=256):
    """Return one random ``np.int32`` in ``[low, high)``."""
    draw = self.rng.integers(low=low, high=high, size=1)
    asInt32 = np.int32(draw)
    return asInt32[0]
279
def getRandNumberDType(self, dtype):
    """Return a single random value of ``dtype``.

    The value is drawn from ``self.getDTypeRange(dtype)`` (high exclusive).
      * BOOL: random choice of False/True;
      * INT48/SHAPE: ``np.int64``;
      * FP32: ``np.float32``; FP16: ``np.float16``;
      * BF16/FP8E4M3/FP8E5M2: an FP32 draw passed through the matching
        ``gtu.vect_f32_to_*`` conversion (its return value is returned as is);
      * anything else: ``np.int32``.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype == DType.BOOL:
        return self.rng.choice([False, True])
    if dtype in (DType.INT48, DType.SHAPE):
        # These ranges can exceed int32, so keep a 64-bit value
        return np.int64(self.rng.integers(low, high, size=1))[0]
    if dtype == DType.FP32:
        return np.float32(self.rng.uniform(low=low, high=high))
    if dtype == DType.FP16:
        return np.float16(self.rng.uniform(low=low, high=high))

    narrowings = (
        (DType.BF16, gtu.vect_f32_to_bf16),
        (DType.FP8E4M3, gtu.vect_f32_to_fp8e4m3),
        (DType.FP8E5M2, gtu.vect_f32_to_fp8e5m2),
    )
    for floatType, convert in narrowings:
        if dtype == floatType:
            return convert(np.float32(self.rng.uniform(low=low, high=high)))

    return np.int32(self.rng.integers(low, high, size=1))[0]
303
def shapeStr(self, shape):
    """Return ``shape`` as dimensions joined by "x", e.g. ``[1, 2, 3]`` -> "1x2x3"."""
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700312
def typeStr(self, dtype):
    """Return the short type name used in test names.

    A list/tuple of at least two dtypes is rendered as the names of its first
    two entries joined by "x". Any third entry (the accumulator type) is
    still converted but left out of the name. A single dtype is looked up in
    ``gtu.DTYPE_ATTRIBUTES``.

    Raises:
        Exception: if a dtype has no entry in ``gtu.DTYPE_ATTRIBUTES``.
    """
    if not isinstance(dtype, (list, tuple)):
        if dtype not in gtu.DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return gtu.DTYPE_ATTRIBUTES[dtype]["str"]

    assert len(dtype) >= 2
    # Every entry is converted, so unknown types still raise, but only the
    # first two appear in the name
    names = [self.typeStr(t) for t in dtype]
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700326
def constrictBatchSize(self, shape):
    """Cap ``shape[0]`` at ``args.max_batch_size`` in place and return ``shape``.

    No change is made if no maximum is configured or if explicit target
    shapes were requested.
    """
    limit = self.args.max_batch_size
    if limit and not self.args.target_shapes:
        shape[0] = min(shape[0], limit)
    return shape
332
def makeDimension(self):
    """Return one random dimension from ``args.tensor_shape_range`` (high exclusive)."""
    dimLow, dimHigh = self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    return self.randInt(low=dimLow, high=dimHigh)
337
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Build the compliance meta data entry for one expected output tensor.

    Args:
        op: op definition dict. Reads ``op["op"]`` and, when present,
            ``op["compliance"]`` (keys "ulp", "relative",
            "abs_error_lower_bound", "abs_error_normal_divisor").
        inputType: dtype of the op input; callers in this file pass the
            first input's dtype.
        argsDict: test argument dict. Always reads "dg_type"; depending on
            the selected mode it also reads "s", "ksb"/"ks",
            "max_abs_value" or "n".
        outputTensor: result tensor; its ``dtype`` sets "data_type".
        errorName: ERROR_IF test name, or a falsy value for a normal test.

    Returns:
        A dict with "mode" (a ``gtu.ComplianceMode`` name), "data_type" and
        at most one mode-specific ``*_info`` entry. Returns None for error
        tests and unsupported types.
    """
    # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
    # (the check below skips these ops for any input type that
    # gtu.dtypeIsSupportedByCompliance rejects)
    UNSUPPORTED_NON_FP32_INPUT_OPS = (
        Op.MATMUL,
        Op.CONV2D,
        Op.FULLY_CONNECTED,
        Op.DEPTHWISE_CONV2D,
        Op.TRANSPOSE_CONV2D,
        Op.CONV3D,
    )
    if (
        errorName
        or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
        or (
            not gtu.dtypeIsSupportedByCompliance(inputType)
            and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
        )
    ):
        # No compliance for error tests or unsupported types currently
        return None

    # Create compliance meta data for expected output tensor
    compliance_tens = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }
    # Mode selection: the first matching branch wins, so data generator type
    # takes precedence over per-op "compliance" settings, which take
    # precedence over the op-specific defaults further down
    if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        compliance_tens["dot_product_info"] = {
            "s": argsDict["s"],
            # "ksb" is used in preference to "ks" when present
            "ks": int(argsDict["ksb"])
            if "ksb" in argsDict
            else int(argsDict["ks"]),
        }
    elif argsDict["dg_type"] == gtu.DataGenType.SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "compliance" in op and "ulp" in op["compliance"]:
        mode = gtu.ComplianceMode.ULP
        compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
    elif "compliance" in op and "relative" in op["compliance"]:
        mode = gtu.ComplianceMode.RELATIVE
        compliance_tens["relative_info"] = {
            "max": argsDict["max_abs_value"],
            "scale": op["compliance"]["relative"],
        }
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
        compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = gtu.ComplianceMode.ABS_ERROR
        # Optional lower bound override from the op definition
        if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
            compliance_tens["abs_error_info"] = {
                "lower_bound": op["compliance"]["abs_error_lower_bound"]
            }
    elif op["op"] in (Op.SIN, Op.COS):
        mode = gtu.ComplianceMode.ABS_ERROR
        # Optional normal divisor override from the op definition
        if "compliance" in op and "abs_error_normal_divisor" in op["compliance"]:
            compliance_tens["abs_error_info"] = {
                "normal_divisor": op["compliance"]["abs_error_normal_divisor"]
            }
    else:
        mode = gtu.ComplianceMode.EXACT
    # Store the mode by enum name
    compliance_tens["mode"] = gtu.ComplianceMode(mode).name

    return compliance_tens
406
407 # Build Op functions
408 # Create the output tensor (calling OutputShaper as needed)
409 # Do final tweaks to attributes (if necessary for errorIf)
410 # Add Op into graph
411 # Return resulting tensor information or BuildInfo
412
class BuildInfo:
    """Enhanced build information containing result tensor and associated compliance dict."""

    def __init__(self, resultTensor, complianceDict):
        # A single tensor (and its optional dict) is stored as 1-element lists
        if not isinstance(resultTensor, list):
            self.resultTensorList = [resultTensor]
            self.complianceDictList = (
                None if complianceDict is None else [complianceDict]
            )
            return

        # Multiple outputs: compliance must be None or a parallel list
        assert complianceDict is None or isinstance(complianceDict, list)
        self.resultTensorList = resultTensor
        self.complianceDictList = complianceDict

    def getComplianceInfo(self):
        """Return ``{"version": "0.1", "tensors": {name: dict}}`` or None.

        Only tensors with a non-None compliance dict are included. None is
        returned when there is no compliance list or no tensor has a dict.
        """
        if self.complianceDictList is None:
            return None

        perTensor = {
            tens.name: comp
            for tens, comp in zip(self.resultTensorList, self.complianceDictList)
            if comp is not None
        }
        if not perTensor:
            return None
        return {
            "version": "0.1",
            "tensors": perTensor,
        }
Eric Kunzee5e26762020-10-13 16:11:07 -0700446
def build_unary(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a single-input operator.

    Args:
        op: op definition dict; reads op["op"] and op["operands"].
        inputs: sequence holding exactly one input tensor.
        args_dict: test arguments, passed to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, if any.
        qinfo: zero points ``[input_zp, output_zp]``; used for NEGATE.

    Returns:
        TosaTestGen.BuildInfo for the result, or None when ERROR_IF
        validation rejects this combination.
    """
    assert len(inputs) == 1
    src = inputs[0]
    outTens = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    assert not isinstance(op, int)

    # A deliberately wrong output type also needs matching zero points
    if (
        error_name == ErrorIf.WrongOutputType
        and outTens.dtype not in [DType.INT8, DType.UINT8]
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, src.dtype),
            TosaQuantGen.getZeroPoint(self, outTens.dtype),
        ]

    # The ERROR_IF generator may invalidate the input/output name lists
    pCount, cCount = op["operands"]
    inNames, outNames = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [outTens.name]
    )

    isValid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=src.dtype,
        output_dtype=outTens.dtype,
        qinfo=qinfo,
        result_tensors=[outTens],
        input_list=inNames,
        output_list=outNames,
        num_operands=pCount + cCount,
    )
    if not isValid:
        return None

    # Only NEGATE carries an attribute (its zero points)
    attr = None
    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], inNames, outNames, attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, outTens, error_name
    )
    return TosaTestGen.BuildInfo(outTens, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700499
def build_binary_broadcast(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Build a two-input operator whose output shape comes from broadcasting.

    Args:
        op: op definition dict; reads op["op"] and op["operands"].
        inputs: exactly two input tensors.
        args_dict: test arguments, passed to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, if any.
        qinfo: unused here; accepted for a uniform builder signature.

    Returns:
        TosaTestGen.BuildInfo for the result, or None when ERROR_IF
        validation rejects this combination.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    outTens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # The ERROR_IF generator may invalidate the input/output name lists
    pCount, cCount = op["operands"]
    inNames, outNames = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [outTens.name]
    )

    isValid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=outTens.dtype,
        result_tensors=[outTens],
        input_list=inNames,
        output_list=outNames,
        num_operands=pCount + cCount,
    )
    if not isValid:
        return None

    self.ser.addOperator(op["op"], inNames, outNames)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, outTens, error_name
    )
    return TosaTestGen.BuildInfo(outTens, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700541
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input operator whose output shape does not broadcast.

    Unlike the other builders here, this takes the two tensors directly,
    performs no ERROR_IF validation and returns the result tensor itself
    rather than a BuildInfo. ``validator_fcns`` and ``error_name`` are
    accepted but unused.
    """
    outTens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operandNames = [a.name, b.name]
    self.ser.addOperator(op["op"], operandNames, [outTens.name])
    return outTens
546
def build_arithmetic_right_shift(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an ARITHMETIC_RIGHT_SHIFT operator with its rounding attribute.

    Args:
        op: op definition dict; reads op["op"] and op["operands"].
        inputs: exactly two input tensors ``(a, b)``, broadcast against
            each other to form the output shape.
        args_dict: test arguments; "round" is passed to
            ``ArithmeticRightShiftAttribute``. The dict is also passed to
            tensorComplianceMetaData.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, if any.
        qinfo: unused here; accepted for a uniform builder signature.

    Returns:
        TosaTestGen.BuildInfo for the result, or None when ERROR_IF
        validation rejects this combination.
    """
    assert len(inputs) == 2
    a, b = inputs
    # Named round_mode (not ``round``) to avoid shadowing the Python builtin
    round_mode = args_dict["round"]
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(round_mode)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800592
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MUL operator from two operands plus a shift value tensor.

    Args:
        op: op definition dict; reads op["op"] and op["operands"].
        inputs: exactly three tensors ``(a, b, shift)``. All three become
            operator inputs; only ``a`` and ``b`` determine the output shape.
        args_dict: test arguments, passed to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, if any.
        qinfo: unused here; accepted for a uniform builder signature.

    Returns:
        TosaTestGen.BuildInfo for the result, or None when ERROR_IF
        validation rejects this combination.
    """
    # MUL is a binary operator but it also takes a shift value tensor
    assert len(inputs) == 3
    lhs, rhs, shift = inputs

    outTens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # Non-float inputs always multiply into an INT32 result
    if lhs.dtype not in (DType.FP16, DType.BF16, DType.FP32):
        outTens.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        # Deliberately replace the output with a randomly chosen integer type
        outTens.setDtype(self.rng.choice([DType.INT8, DType.INT16, DType.INT48]))

    # The ERROR_IF generator may invalidate the input/output name lists
    pCount, cCount = op["operands"]
    inNames, outNames = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name, shift.name], [outTens.name]
    )

    isValid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=outTens.dtype,
        result_tensors=[outTens],
        input_list=inNames,
        output_list=outNames,
        num_operands=pCount + cCount,
    )
    if not isValid:
        return None

    self.ser.addOperator(op["op"], inNames, outNames)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, outTens, error_name
    )
    return TosaTestGen.BuildInfo(outTens, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700645
def build_table(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TABLE operator using ``args_dict["table"]`` as its lookup table.

    Args:
        op: op definition dict; reads op["op"] and op["operands"].
        inputs: sequence holding exactly one input tensor.
        args_dict: test arguments; "table" is passed to ``TableAttribute``.
            The dict is also passed to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, if any.
        qinfo: unused here; accepted for a uniform builder signature.

    Returns:
        TosaTestGen.BuildInfo for the result, or None when ERROR_IF
        validation rejects this combination.
    """
    assert len(inputs) == 1
    src = inputs[0]
    lookup = args_dict["table"]
    outTens = OutputShaper.tableOp(self.ser, self.rng, src, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TableAttribute(lookup)

    # The ERROR_IF generator may invalidate the input/output name lists
    pCount, cCount = op["operands"]
    inNames, outNames = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [outTens.name]
    )

    isValid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_dtype=outTens.dtype,
        result_tensors=[outTens],
        input_list=inNames,
        output_list=outNames,
        num_operands=pCount + cCount,
    )
    if not isValid:
        return None

    self.ser.addOperator(op["op"], inNames, outNames, attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, outTens, error_name
    )
    return TosaTestGen.BuildInfo(outTens, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700688
def build_select(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SELECT operator test.

    ``inputs`` holds the condition tensor followed by the two value tensors.
    The output tensor comes from ``OutputShaper.selectOp``; the operator is
    only serialized when ERROR_IF validation accepts the test.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor plus compliance meta data),
        or ``None`` when validation rejects the test.
    """
    assert len(inputs) == 3
    cond, a, b = inputs

    out_tensor = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    # For ERROR_IF tests the input/output name lists may be invalidated.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, a.name, b.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700736
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a binary comparison operator test.

    ``inputs`` holds the two operand tensors. The result tensor is created by
    ``OutputShaper.binaryComparisonOp`` and the operator is serialized only
    when ERROR_IF validation accepts the test.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance meta
        data, or ``None`` when validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # For ERROR_IF tests the input/output name lists may be invalidated.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700784
def build_argmax(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an ARGMAX operator test.

    ``validator_fcns`` and ``error_name`` now default to ``None`` like every
    other ``build_*`` method; positional callers are unaffected.

    Args:
        op: operator description dict (``op["op"]``, ``op["operands"]``).
        inputs: single-element list holding the input tensor.
        args_dict: must contain ``axis``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or ``None``.
        qinfo: unused by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance meta
        data, or ``None`` when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700828
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator test.

    Reads ``stride``, ``pad`` and ``kernel`` (and optionally ``acc_type``)
    from ``args_dict``, creates the output via ``OutputShaper.pool2dOp`` and
    serializes the operator with a PoolAttribute when ERROR_IF validation
    accepts the test. A ``None`` ``qinfo`` is serialized as zero points
    ``[0, 0]``.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance meta
        data, or ``None`` when validation rejects the test.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    # max_pool has no accum_dtype, so fall back to UNKNOWN
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    out_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # WrongInputType with a non-8-bit input: derive the zero points from the
    # actual input and output dtypes.
    wrong_non_8bit = error_name == ErrorIf.WrongInputType and ifm.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]
    if wrong_non_8bit:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    # For ERROR_IF tests the input/output name lists may be invalidated.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        accum_dtype=accum_dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_ok:
        return None

    # Validators above see the caller's qinfo; only the attribute gets [0, 0].
    zero_points = [0, 0] if qinfo is None else qinfo

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )
    self.ser.addOperator(op["op"], in_names, out_names, pool_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100903
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a conv2d operator test (output created by OutputShaper.conv2dOp).

    Args:
        op: operator description dict (``op["op"]``, ``op["operands"]``).
        inputs: ``[ifm, weight, bias]`` tensors.
        args_dict: must contain ``acc_type``, ``stride``, ``pad`` (4 values)
            and ``dilation``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or ``None``.
        qinfo: two-element zero-point list written into the ConvAttribute.
            If ``None``, ``[0, 0]`` is serialized instead of failing on
            ``qinfo[0]``.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance meta
        data, or ``None`` when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The default qinfo=None would otherwise raise TypeError when indexed.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700985
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a conv3d operator test (output created by OutputShaper.conv3dOp).

    Args:
        op: operator description dict (``op["op"]``, ``op["operands"]``).
        inputs: ``[ifm, weight, bias]`` tensors.
        args_dict: must contain ``acc_type``, ``stride``, ``pad`` (6 values)
            and ``dilation``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or ``None``.
        qinfo: two-element zero-point list written into the ConvAttribute.
            If ``None``, ``[0, 0]`` is serialized instead of failing on
            ``qinfo[0]``.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance meta
        data, or ``None`` when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The default qinfo=None would otherwise raise TypeError when indexed.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001067
def build_transpose_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a transpose-conv2d test (output from OutputShaper.transposeConv2DOp).

    Args:
        op: operator description dict (``op["op"]``, ``op["operands"]``).
        inputs: ``[ifm, weight, bias]`` tensors.
        args_dict: must contain ``acc_type``, ``stride``, ``pad`` (4 values,
            used as out_pad) and ``out_shape``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or ``None``.
        qinfo: two-element zero-point list written into the
            TransposeConvAttribute. If ``None``, ``[0, 0]`` is serialized
            instead of failing on ``qinfo[0]``.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance meta
        data, or ``None`` when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The default qinfo=None would otherwise raise TypeError when indexed.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001142
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a depthwise conv2d test (output from OutputShaper.depthwiseConv2dOp).

    Args:
        op: operator description dict (``op["op"]``, ``op["operands"]``).
        inputs: ``[ifm, weight, bias]`` tensors.
        args_dict: must contain ``acc_type``, ``stride``, ``pad`` and
            ``dilation``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or ``None``.
        qinfo: two-element zero-point list written into the ConvAttribute.
            If ``None``, ``[0, 0]`` is serialized instead of failing on
            ``qinfo[0]``.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance meta
        data, or ``None`` when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The default qinfo=None would otherwise raise TypeError when indexed.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001223
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a fully-connected test (output from OutputShaper.fullyConnectedOp).

    Args:
        op: operator description dict (``op["op"]``, ``op["operands"]``).
        inputs: ``[ifm, weight, bias]`` tensors.
        args_dict: must contain ``acc_type``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or ``None``.
        qinfo: two-element zero-point list written into the
            FullyConnectedAttribute. If ``None``, ``[0, 0]`` is serialized
            instead of failing on ``qinfo[0]``.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance meta
        data, or ``None`` when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # The default qinfo=None would otherwise raise TypeError when indexed.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001279
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator test (output created by OutputShaper.matmulOp).

    Args:
        op: operator description dict (``op["op"]``, ``op["operands"]``).
        inputs: ``[a, b]`` operand tensors.
        args_dict: must contain ``acc_type``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or ``None``.
        qinfo: two-element zero-point list written into the MatMulAttribute.
            If ``None``, ``[0, 0]`` is serialized instead of failing on
            ``qinfo[0]``.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance meta
        data, or ``None`` when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # The default qinfo=None would otherwise raise TypeError when indexed.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001329
def build_reduce(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REDUCE_* operator test (output created by OutputShaper.reduceOp).

    ``validator_fcns`` now defaults to ``None`` like the other ``build_*``
    methods; positional callers are unaffected.

    Args:
        op: operator description dict (``op["op"]``, ``op["operands"]``).
        inputs: single-element list holding the input tensor.
        args_dict: must contain ``axis``. For a non-error REDUCE_PRODUCT test
            this method also stores ``args_dict["n"]`` (the reduced axis
            length) for compliance meta data, mutating the caller's dict.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or ``None``.
        qinfo: unused by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance meta
        data, or ``None`` when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance
        args_dict["n"] = a.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001378
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CLAMP operator with randomly generated bounds.

    Two random values of the input dtype are drawn. Normally the smaller one
    becomes min_val and the larger one max_val. For the MaxSmallerMin
    ERROR_IF case the pair is redrawn until the values differ, and is then
    assigned the other way round so that max_val < min_val.

    Each bound is serialized as 4 little-endian bytes: float32 for
    BF16/FP16/FP32 inputs, int32 for INT8/INT16 inputs, and an int32 zero for
    any other input dtype.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None if
        the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]

    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    def draw_pair():
        # Two random values of the input dtype, drawn left to right
        return [self.getRandNumberDType(src.dtype) for _ in range(2)]

    bounds = draw_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # Make sure the numbers are different to invoke this error
        while bounds[0] == bounds[1]:
            bounds = draw_pair()
        # Deliberately inverted: the larger value becomes the minimum
        min_val, max_val = max(bounds), min(bounds)
    else:
        min_val, max_val = min(bounds), max(bounds)

    # For ERROR_IF tests the operand name lists may be invalidated here.
    pos_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [result_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pos_count + const_count,
    )
    if not passed:
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    if src.dtype in (DType.BF16, DType.FP16, DType.FP32):
        if src.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        packed_bounds = (struct.pack("<f", min_val), struct.pack("<f", max_val))
    elif src.dtype in (DType.INT8, DType.INT16):
        packed_bounds = (struct.pack("<i", min_val), struct.pack("<i", max_val))
    else:
        # to avoid internal error for incorrect input types
        packed_bounds = (struct.pack("<i", 0), struct.pack("<i", 0))

    clamp_attr.ClampAttribute(self.ser.builder, *packed_bounds)
    self.ser.addOperator(op["op"], input_list, output_list, clamp_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001451
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a LEAKY_RELU operator using the legacy single-tensor interface.

    The operator's attribute value is a random FP32 number. validator_fcns is
    accepted but no ERROR_IF validation is run. Unlike the newer builders,
    this returns the result tensor itself rather than a BuildInfo.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    relu_attr = ts.TosaSerializerAttribute()
    relu_attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32))

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], relu_attr)
    return out_tensor
1460
# TODO: Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PRELU operator using the legacy single-tensor interface.

    Only the single input tensor is wired up and no attribute is attached.
    validator_fcns is accepted but no ERROR_IF validation is run. The result
    tensor itself is returned rather than a BuildInfo.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1467
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an attribute-free elementwise activation operator.

    The output shape/type comes from OutputShaper.unaryOp. qinfo is unused.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None if
        the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    # For ERROR_IF tests the operand name lists may be invalidated here.
    n_pos, n_const = op["operands"]
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=names_in,
        output_list=names_out,
        num_operands=n_pos + n_const,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], names_in, names_out)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001508
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONCAT or CONCAT_SHAPE operator over every tensor in inputs.

    CONCAT_SHAPE always concatenates along axis 0 and is serialized without
    an attribute. CONCAT reads its axis from args_dict["axis"] and serializes
    it as an AxisAttribute.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None if
        the ERROR_IF validation rejects the test.
    """
    is_shape_concat = op["op"] == Op.CONCAT_SHAPE
    axis = 0 if is_shape_concat else args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        assert type(axis) is int

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # For ERROR_IF tests the operand name lists may be invalidated here.
    n_pos, n_const = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [t.name for t in inputs], [result_tensor.name]
    )

    first = inputs[0]
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=first.shape,
        output_shape=result_tensor.shape,
        input_dtype=first.dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_pos + n_const,
    ):
        return None

    if op["op"] == Op.CONCAT:
        concat_attr = ts.TosaSerializerAttribute()
        concat_attr.AxisAttribute(axis)
    else:
        assert is_shape_concat
        concat_attr = None
    self.ser.addOperator(op["op"], input_list, output_list, concat_attr)

    compliance = self.tensorComplianceMetaData(
        op, first.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001567
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator.

    Args:
        inputs: [input tensor, second operand tensor]. The second operand is
            presumably the padding amounts as a tensor - TODO confirm.
        args_dict: Reads "pad" (used for the output shape and ERROR_IF
            checks), "pad_const_int" and "pad_const_fp".

    The pad constant is serialized as 4 little-endian bytes: float32 from
    "pad_const_fp" for floating-point inputs, int32 from "pad_const_int"
    otherwise.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None if
        the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    src, pad_operand = inputs[0], inputs[1]
    padding, pad_const_int, pad_const_float = (
        args_dict["pad"],
        args_dict["pad_const_int"],
        args_dict["pad_const_fp"],
    )

    result_tensor = OutputShaper.padOp(self.ser, self.rng, src, padding, error_name)

    pad_const_bytes = (
        struct.pack("<f", pad_const_float)
        if gtu.dtypeIsFloat(src.dtype)
        else struct.pack("<i", pad_const_int)
    )
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, pad_const_bytes)

    # For ERROR_IF tests the operand name lists may be invalidated here.
    n_pos, n_const = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, pad_operand.name], [result_tensor.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_pos + n_const,
        input1=src,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001630
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator for the axis given in args_dict["axis"].

    The output tensor comes from OutputShaper.dimOp. No compliance metadata
    is produced: the returned BuildInfo carries None for it. Returns None if
    the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.dimOp(self.ser, self.rng, src, axis, error_name)

    # For ERROR_IF tests the operand name lists may be invalidated here.
    n_pos, n_const = op["operands"]
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=names_in,
        output_list=names_out,
        num_operands=n_pos + n_const,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    dim_attr = ts.TosaSerializerAttribute()
    dim_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], names_in, names_out, dim_attr)

    return TosaTestGen.BuildInfo(out_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001676
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a RESHAPE operator.

    Args:
        inputs: [input tensor, shape operand tensor]. Both names are wired in
            as operator inputs.
        args_dict: "new_shape" is used to compute the output tensor. It
            presumably matches the shape operand's values - not checked here.

    No attribute is serialized. Returns TosaTestGen.BuildInfo(result tensor,
    compliance metadata), or None if the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    src, shape_operand = inputs[0], inputs[1]
    new_shape = args_dict["new_shape"]
    result_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, src, new_shape, error_name
    )

    # For ERROR_IF tests the operand name lists may be invalidated here.
    n_pos, n_const = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, shape_operand.name], [result_tensor.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_pos + n_const,
    )
    if not ok:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001720
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REVERSE operator along args_dict["axis"].

    The output tensor comes from OutputShaper.unaryOp. The axis is serialized
    as an AxisAttribute. No compliance metadata is produced: the returned
    BuildInfo carries None for it. Returns None if the ERROR_IF validation
    rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    # For ERROR_IF tests the operand name lists may be invalidated here.
    n_pos, n_const = op["operands"]
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=names_in,
        output_list=names_out,
        num_operands=n_pos + n_const,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    reverse_attr = ts.TosaSerializerAttribute()
    reverse_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], names_in, names_out, reverse_attr)

    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001760
def build_transpose(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TRANSPOSE operator using the permutation in args_dict["perms"].

    The permutation is serialized as a TransposeAttribute.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None if
        the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    perms = args_dict["perms"]

    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, src, perms, error_name)

    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(perms)

    # For ERROR_IF tests the operand name lists may be invalidated here.
    n_pos, n_const = op["operands"]
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=names_in,
        output_list=names_out,
        num_operands=n_pos + n_const,
        input1=src,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], names_in, names_out, transpose_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001809
def build_slice(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SLICE operator.

    Args:
        inputs: [input tensor, start operand tensor, size operand tensor].
            All three names are wired in as operator inputs.
        args_dict: "start" and "size" are used to compute the output tensor
            and for the ERROR_IF checks. They presumably match the start/size
            operand tensors' values - not checked here.

    No attribute is serialized. Returns TosaTestGen.BuildInfo(result tensor,
    compliance metadata), or None if the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    src, start_operand, size_operand = inputs
    start_values = args_dict["start"]
    size_values = args_dict["size"]

    out_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, src, start_values, size_values, error_name
    )

    # For ERROR_IF tests the operand name lists may be invalidated here.
    n_pos, n_const = op["operands"]
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [src.name, start_operand.name, size_operand.name],
        [out_tensor.name],
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        start=start_values,
        size=size_values,
        result_tensors=[out_tensor],
        input_list=names_in,
        output_list=names_out,
        num_operands=n_pos + n_const,
        input1=src,
    )
    if not valid:
        return None

    self.ser.addOperator(op["op"], names_in, names_out)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001857
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TILE operator.

    Args:
        inputs: [input tensor, multiples operand tensor]. Both names are wired
            in as operator inputs.
        args_dict: "multiples" is used to compute the output tensor. It
            presumably matches the multiples operand's values - not checked
            here.

    No attribute is serialized. Returns TosaTestGen.BuildInfo(result tensor,
    compliance metadata), or None if the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    src, multiples_operand = inputs[0], inputs[1]
    multiples_values = args_dict["multiples"]
    out_tensor = OutputShaper.tileOp(
        self.ser, self.rng, src, multiples_values, error_name
    )

    # For ERROR_IF tests the operand name lists may be invalidated here.
    n_pos, n_const = op["operands"]
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, multiples_operand.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=names_in,
        output_list=names_out,
        num_operands=n_pos + n_const,
        input1=src,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], names_in, names_out)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001902
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a GATHER operator from a (values, indices) tensor pair.

    The output tensor comes from OutputShaper.gatherOp. The ERROR_IF checks
    and compliance metadata use the values tensor's shape/dtype. No attribute
    is serialized.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None if
        the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    values, indices = inputs

    out_tensor = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    # For ERROR_IF tests the operand name lists may be invalidated here.
    n_pos, n_const = op["operands"]
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out_tensor.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=names_in,
        output_list=names_out,
        num_operands=n_pos + n_const,
    )
    if not valid:
        return None

    self.ser.addOperator(op["op"], names_in, names_out)

    compliance = self.tensorComplianceMetaData(
        op, values.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001945
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SCATTER operator from (values_in, indices, input) tensors.

    The output tensor comes from OutputShaper.scatterOp. The ERROR_IF checks
    and compliance metadata use the values_in tensor's shape/dtype. No
    attribute is serialized.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None if
        the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    # "input_tensor" avoids shadowing the builtin input()
    values_in, indices, input_tensor = inputs
    out_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input_tensor, error_name
    )

    # For ERROR_IF tests the operand name lists may be invalidated here.
    n_pos, n_const = op["operands"]
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input_tensor.name],
        [out_tensor.name],
    )

    checks = dict(
        op=op,
        input_shape=values_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=names_in,
        output_list=names_out,
        num_operands=n_pos + n_const,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], names_in, names_out)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001987
def build_resize(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns,
    error_name=None,
    qinfo=None,
):
    """Build a RESIZE operator test.

    ``inputs`` holds four tensors: the data input followed by the scale,
    offset and border tensors, all of which become operator inputs. The
    scale/offset/border *values* used for output-shape calculation and
    ERROR_IF validation are read from ``args_dict``. The serialized
    ResizeAttribute carries only ``mode``; its scale/offset/border lists are
    written empty.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects the test. ``qinfo`` is not used.
    """
    assert len(inputs) == 4
    in_tens, scale_tens, offset_tens, border_tens = inputs

    mode, scale, offset, border, output_dtype = (
        args_dict[key]
        for key in ("mode", "scale", "offset", "border", "output_dtype")
    )

    out_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        in_tens,
        mode,
        scale,
        offset,
        border,
        in_tens.dtype,
        output_dtype,
        error_name,
    )

    # ERROR_IF tests may have their input/output name lists invalidated here.
    operand_names = [
        in_tens.name,
        scale_tens.name,
        offset_tens.name,
        border_tens.name,
    ]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, operand_names, [out_tens.name]
    )

    check_kwargs = {
        "op": op,
        "mode": mode,
        "scale": scale,
        "input_dtype": in_tens.dtype,
        "output_dtype": output_dtype,
        "input_shape": in_tens.shape,
        "output_shape": out_tens.shape,
        "offset": offset,
        "border": border,
        "input_list": in_names,
        "output_list": out_names,
        "result_tensors": [out_tens],
        "num_operands": p_count + c_count,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_kwargs
    ):
        return None

    resize_attr = ts.TosaSerializerAttribute()
    # The scale/offset/border operands are passed as input tensors above;
    # the attribute lists are written empty.
    resize_attr.ResizeAttribute([], [], [], mode)
    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)

    compliance = self.tensorComplianceMetaData(
        op, in_tens.dtype, args_dict, out_tens, error_name
    )
    return TosaTestGen.BuildInfo(out_tens, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002066
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator on two input tensors.

    Each output tensor is shaped by ``OutputShaper.unaryOp`` from its input.

    NOTE(review): unlike the other builders in this class, ``op`` is passed
    straight to ``addOperator`` (not ``op["op"]``), ``validator_fcns`` is
    ignored, and only the first result tensor is returned rather than a
    ``BuildInfo``. Confirm the caller still uses this older convention.
    """
    first_out, second_out = (
        OutputShaper.unaryOp(self.ser, self.rng, tens, error_name)
        for tens in (val, val2)
    )
    self.ser.addOperator(
        op,
        [val.name, val2.name],
        [first_out.name, second_out.name],
    )
    return first_out
2074
def build_const(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONST test by registering the single input tensor as an output.

    Returns:
        ``TosaTestGen.BuildInfo`` wrapping that same tensor together with its
        compliance metadata. ``validator_fcns`` and ``qinfo`` are not used.
    """
    assert len(inputs) == 1
    (const_tens,) = inputs
    self.ser.addOutputTensor(const_tens)

    return TosaTestGen.BuildInfo(
        const_tens,
        self.tensorComplianceMetaData(
            op, const_tens.dtype, args_dict, const_tens, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002087
2088 # Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CAST operator test converting to ``args_dict["out_type"]``.

    Args:
        op: Operator description dict (``op["op"]``, ``op["operands"]``).
        inputs: Exactly one tensor, the value to convert.
        args_dict: Test arguments; ``out_type`` selects the target dtype.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the converted result tensor, or None
        when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    (src,) = inputs
    target_dtype = args_dict["out_type"]

    out_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, src, target_dtype, error_name
    )

    # ERROR_IF tests may have their input/output name lists invalidated here.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tens.shape,
        input_dtype=src.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tens, error_name
    )
    return TosaTestGen.BuildInfo(out_tens, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002132
def build_rescale(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESCALE operator test with randomly chosen zero points.

    ``inputs`` is ``[val, multiplier_val, shift_val]``. The ``args_dict`` keys
    read are ``output_dtype``, ``scale`` (scale32 flag), ``double_round``,
    ``per_channel``, ``shift`` and ``multiplier``.

    For a valid (``error_name is None``) scale32 test, the input data file is
    rewritten so that ``value - input_zp`` lies within
    ``[-1 << (shift - 1), (1 << (shift - 1)) - 1]`` for each channel.

    The ``qinfo`` argument is ignored: the (input_zp, output_zp) pair chosen
    here is what is passed to the ERROR_IF validators.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    val = inputs[0]
    multiplier_val = inputs[1]
    shift_val = inputs[2]
    out_dtype = args_dict["output_dtype"]
    scale32 = args_dict["scale"]
    double_round = args_dict["double_round"]
    per_channel = args_dict["per_channel"]
    shift_arr = args_dict["shift"]
    multiplier_arr = args_dict["multiplier"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # Number of shift values applied: one per last-dimension channel when
    # per_channel is set, otherwise a single value for the whole tensor.
    if per_channel:
        nc = val.shape[-1]
    else:
        nc = 1

    # NOTE(review): in_type_width / out_type_width are adjusted below but are
    # not read again in this method.
    in_type_width = gtu.dtypeWidth(val.dtype)
    out_type_width = gtu.dtypeWidth(out_dtype)

    input_unsigned = False
    output_unsigned = False

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        # ERROR_IF tests need a non-zero input zero point.
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        # ERROR_IF tests need a non-zero output zero point.
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
        output_unsigned = True
    else:
        output_zp = 0

    # Per-channel bounds for (value - input_zp) required by apply_scale_32.
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
        max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

    logger.debug(
        "build_rescale: multiplier=%s shift=%s inzp=%s outzp=%s",
        multiplier_arr,
        shift_arr,
        input_zp,
        output_zp,
    )
    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        data_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(data_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check we can safely convert back to the data file's dtype by
        # comparing the actual extremes (``ndarray.all()`` would collapse the
        # array to a single bool and make the check meaningless).
        dtype_limits = np.iinfo(values.dtype)
        assert val_adj.size == 0 or (
            dtype_limits.min <= val_adj.min() and val_adj.max() <= dtype_limits.max
        )

        # Cast back to the dtype stored in the data file
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(data_path, val_adj, allow_pickle=False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name, multiplier_val.name, shift_val.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002298
def _get_condition_tensor(self, op, cond, error_name):
    """Add and return a constant condition tensor holding ``cond``.

    Normally this is a rank-0 BOOL constant. For the
    ``CondIfCondNotMatchingBool`` ERROR_IF test the dtype comes from
    ``gtu.get_wrong_output_type``. For ``CondIfCondShapeNotSizeOne`` the shape
    is randomly ``[2]`` or ``[1, 2]``.
    """
    # The dtype is chosen before the shape; get_wrong_output_type is handed
    # the RNG, so keep this order.
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    # Must be of size 1 (rank 0) unless testing the shape ERROR_IF
    cond_shape = []
    if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
        cond_shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2315
def build_cond_if_const(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF whose then/else blocks each output an INT32 constant.

    Only the shape of the first input tensor is used; the input values are
    ignored. ``args_dict["condition"]`` is stored in the condition tensor.
    For the ``CondIfOutputList{Then,Else}GraphMismatch`` ERROR_IF tests, the
    corresponding block outputs a constant whose shape differs in every
    dimension.

    Returns:
        ``TosaTestGen.BuildInfo`` for the COND_IF result tensor, or None when
        the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    then_tens, else_tens = inputs
    cond = args_dict["condition"]

    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape
    dtype = DType.INT32

    mismatch_errors = (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    )
    if error_name in mismatch_errors:
        # Perturb every dimension; only dims > 3 are ever shrunk.
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            deltas = [-3, -2, 2, 3] if dim > 3 else [1, 2, 4]
            incorrect_shape[idx] = dim + self.rng.choice(deltas)
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The operator's result tensor, shaped like the valid block outputs
    result_tensor = self.ser.addOutput(out_shape, dtype)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)

    # Each block just outputs a constant; at most one gets the wrong shape.
    for block_name, block_arr, block_error in (
        (then_block, then_arr, ErrorIf.CondIfOutputListThenGraphMismatch),
        (else_block, else_arr, ErrorIf.CondIfOutputListElseGraphMismatch),
    ):
        self.ser.addBasicBlock(block_name)
        if error_name == block_error:
            block_out = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
        else:
            block_out = self.ser.addConst(out_shape, dtype, block_arr)
        self.ser.addOutputTensor(block_out)

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not valid:
        return None

    compliance = self.tensorComplianceMetaData(
        op, dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002400
def build_cond_if_binary(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF whose then/else blocks apply a binary op to (a, b).

    For FP32/BF16/FP16/INT32 inputs the THEN block uses ``add`` and the ELSE
    block uses ``sub``. For INT8/INT16 inputs they use ``logical_right_shift``
    and ``logical_left_shift``. ``args_dict["condition"]`` selects which
    block's op the compliance metadata is generated for.

    For the ``CondIf{Input,Output}List{Then,Else}GraphMismatch`` ERROR_IF
    tests, a copy of ``a`` with a different (still positive) shape is wired
    into the selected block.

    Returns:
        ``TosaTestGen.BuildInfo`` for the COND_IF result tensor, or None when
        the ERROR_IF validation rejects the test.

    Raises:
        AssertionError: if ``a.dtype`` has no then/else op mapping.
    """
    assert len(inputs) == 2
    a, b = inputs

    cond = args_dict["condition"]

    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tensor = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            # Only shrink dimensions larger than 3 so the mismatched shape
            # stays strictly positive (same rule as build_cond_if_const);
            # an unconditional -2/-3 offset could yield zero/negative sizes.
            incorrect_shape[i] += (
                self.rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else self.rng.choice([1, 2, 4])
            )
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"]
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = (
            self.TOSA_OP_LIST["logical_right_shift"],
            self.TOSA_OP_LIST["logical_left_shift"],
        )
    else:
        # Explicit raise so the check survives `python -O`.
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # Determine the element-wise binary operation that compliance will need to
    # check the results of
    compliance_op = then_op if cond else else_op

    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        compliance_op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2504
2505 def build_while_loop(
2506 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
2507 ):
2508 assert len(inputs) == 1
2509 a = inputs[0]
2510 iter_val = args_dict["iterations"]
2511
Kevin Cheng550ccc52021-03-03 11:21:43 -08002512 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07002513
Kevin Cheng550ccc52021-03-03 11:21:43 -08002514 cond_block = "COND_BLOCK"
2515 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002516
2517 attr = ts.TosaSerializerAttribute()
2518 attr.WhileLoopAttribute(cond_block, body_block)
2519
2520 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08002521 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002522 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08002523 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07002524
2525 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002526 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2527 a_out = self.ser.addIntermediate(a.shape, a.dtype)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002528 if error_name == ErrorIf.InputListOutputListMismatch:
2529 incorrect_acc = deepcopy(acc)
2530 for i in range(len(incorrect_acc.shape)):
2531 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2532 acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
2533 else:
2534 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002535
2536 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08002537 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002538 op["op"],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002539 [iter.name, a.name, acc.name],
2540 [iter_out.name, a_out.name, acc_out.name],
2541 attr,
2542 )
Kevin Chengb227ae52021-09-02 13:43:17 -07002543 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07002544
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002545 if error_name in [
2546 ErrorIf.InputListCondGraphMismatch,
2547 ErrorIf.InputListBodyGraphInputMismatch,
2548 ErrorIf.InputListBodyGraphOutputMismatch,
2549 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01002550 incorrect_iter = deepcopy(iter)
2551 for i in range(len(incorrect_iter.shape)):
2552 incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
2553 if len(incorrect_iter.shape) == 0:
2554 incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
2555
2556 incorrect_acc = deepcopy(acc)
2557 for i in range(len(incorrect_acc.shape)):
2558 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2559
Eric Kunzee5e26762020-10-13 16:11:07 -07002560 # COND block (input: iter, output: cond_tens )
Jerry Ge9e94af82022-10-27 09:57:00 -07002561 self.ser.addBasicBlock(cond_block)
2562
Matthew Haddon630c17c2021-10-14 15:05:41 +01002563 if error_name == ErrorIf.InputListCondGraphMismatch:
2564 self.ser.addInputTensor(incorrect_iter)
2565 self.ser.addInputTensor(a)
2566 self.ser.addInputTensor(incorrect_acc)
2567 else:
2568 self.ser.addInputTensor(iter)
2569 self.ser.addInputTensor(a)
2570 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002571 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002572
2573 if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002574 cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002575 else:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002576 cond_type = DType.BOOL
2577 if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
2578 choice = self.rng.choice([1, 2])
2579 if choice == 1:
2580 cond_shape = [3]
2581 else:
2582 cond_shape = [1, 2]
2583 else:
2584 cond_shape = []
2585 cond_tens = self.ser.addOutput(cond_shape, cond_type)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002586
Kevin Cheng550ccc52021-03-03 11:21:43 -08002587 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002588
2589 # BODY block (input: a, acc, iter, output: a, acc, iter)
2590 # Note that local intermediate tensors need to be declared here for the outputs
Jerry Ge9e94af82022-10-27 09:57:00 -07002591 self.ser.addBasicBlock(body_block)
2592
Matthew Haddon630c17c2021-10-14 15:05:41 +01002593 if error_name == ErrorIf.InputListBodyGraphInputMismatch:
2594 self.ser.addInputTensor(incorrect_iter)
2595 self.ser.addInputTensor(a)
2596 self.ser.addInputTensor(incorrect_acc)
2597 else:
2598 self.ser.addInputTensor(iter)
2599 self.ser.addInputTensor(a)
2600 self.ser.addInputTensor(acc)
2601
Kevin Cheng550ccc52021-03-03 11:21:43 -08002602 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002603
2604 if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002605 iter_body_out = self.ser.addIntermediate(
2606 incorrect_iter.shape, incorrect_iter.dtype
2607 )
2608 acc_body_out = self.ser.addIntermediate(
2609 incorrect_acc.shape, incorrect_acc.dtype
2610 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01002611 else:
2612 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2613 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
2614
Eric Kunzee5e26762020-10-13 16:11:07 -07002615 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2616 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2617 self.ser.addOutputTensor(iter_body_out)
2618 self.ser.addOutputTensor(a)
2619 self.ser.addOutputTensor(acc_body_out)
2620
Les Bell729b0352021-11-24 10:28:21 +00002621 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01002622 self.ser,
2623 validator_fcns,
2624 error_name,
2625 op=op,
Jerry Ge9e94af82022-10-27 09:57:00 -07002626 basicBlocks=self.ser.currRegion.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00002627 ):
2628 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01002629
Jeremy Johnson587cc842024-02-08 11:45:44 +00002630 compliance = self.tensorComplianceMetaData(
2631 op, a.dtype, args_dict, acc_out, error_name
2632 )
2633
2634 return TosaTestGen.BuildInfo(acc_out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002635
def build_fft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an FFT2D operator test from two input tensors.

    Args:
        op: TOSA_OP_LIST entry for the operator under test.
        inputs: sequence of exactly two input tensors.
        args_dict: test argument dictionary; must provide "inverse".
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF condition being generated, or None for a
            positive test.
        qinfo: not used by this builder.

    Returns:
        TosaTestGen.BuildInfo wrapping the list of result tensors and a list
        holding one compliance entry per result, or None when the ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 2
    first, second = inputs
    is_inverse = args_dict["inverse"]

    results = OutputShaper.fft2dOp(self.ser, self.rng, first, second, error_name)

    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count

    # Capture the result metadata before the name lists are (possibly) altered.
    result_shapes = [res.shape for res in results]
    result_dtypes = [res.dtype for res in results]

    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [first.name, second.name], [res.name for res in results]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=is_inverse,
        input1=first,
        input2=second,
        input_shape=first.shape,
        input_dtype=first.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not valid:
        return None

    # TODO - Test local_bound; until then the attribute is always False.
    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(is_inverse, False)
    self.ser.addOperator(op["op"], in_names, out_names, fft_attr)

    compliance = [
        self.tensorComplianceMetaData(op, first.dtype, args_dict, res, error_name)
        for res in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002699
def build_rfft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an RFFT2D operator test from a single input tensor.

    Args:
        op: TOSA_OP_LIST entry for the operator under test.
        inputs: sequence containing exactly one input tensor.
        args_dict: test argument dictionary, forwarded to the compliance
            meta data generator.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF condition being generated, or None for a
            positive test.
        qinfo: not used by this builder.

    Returns:
        TosaTestGen.BuildInfo wrapping the list of result tensors and a list
        holding one compliance entry per result, or None when the ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    source = inputs[0]
    results = OutputShaper.rfft2dOp(self.ser, self.rng, source, error_name)

    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count

    # Capture the result metadata before the name lists are (possibly) altered.
    result_shapes = [res.shape for res in results]
    result_dtypes = [res.dtype for res in results]

    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [source.name], [res.name for res in results]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=source.shape,
        input_dtype=source.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not valid:
        return None

    # TODO - Test local_bound; until then the attribute is always False.
    rfft_attr = ts.TosaSerializerAttribute()
    rfft_attr.RFFTAttribute(False)
    self.ser.addOperator(op["op"], in_names, out_names, rfft_attr)

    compliance = [
        self.tensorComplianceMetaData(op, source.dtype, args_dict, res, error_name)
        for res in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002756
def build_shape_op(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a two-input, one-output shape operation test.

    The result tensor comes from OutputShaper.addShapeOp.

    Args:
        op: TOSA_OP_LIST entry for the operator under test.
        inputs: sequence of exactly two input tensors.
        args_dict: test argument dictionary, forwarded to the compliance
            meta data generator.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF condition being generated, or None for a
            positive test.
        qinfo: not used by this builder.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None when the
        ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.addShapeOp(self.ser, self.rng, lhs, rhs, error_name)

    placeholder_count, const_count = op["operands"]

    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
2802
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Resolve the shape, rank and dtype filters used to enumerate tests.

    Args:
        op: TOSA_OP_LIST entry; its "rank" (min, max) and "types" are read.
        shapeFilter: requested shapes; a falsy value means [None] (no shape
            restriction).
        rankFilter: requested ranks, or None for defaults.
        dtypeFilter: requested dtypes, or None for all operator dtypes. A
            list-valued operator dtype matches on its first element.
        testType: "positive" or "negative".
        validator: ERROR_IF validator; required for negative tests. Its
            "param_reqs" override the rank/dtype/shape filters.

    Returns:
        dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or None
        for a negative test without a validator or for any other testType.
    """
    # Default testing rank range, 1-4 inclusive, keeps test sizes reasonably small.
    default_ranks = range(1, 5)
    shapeFilter = shapeFilter or [None]

    low_rank, high_rank = op["rank"]
    op_ranks = range(low_rank, high_rank + 1)

    if rankFilter is not None:
        # Keep only the requested ranks the operator allows.
        ranks = [rank for rank in rankFilter if low_rank <= rank <= high_rank]
    elif shapeFilter[0] is None:
        # Neither ranks nor shapes requested: bound by whichever of the
        # operator range and the default range is smaller.
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
    else:
        ranks = op_ranks

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }

    if testType != "negative" or validator is None:
        return None

    reqs = validator(check=False, op=op)["param_reqs"]
    # Requirements from the validator win; otherwise fall back to the
    # positive filters (with at most two shapes to keep test counts small).
    neg_ranks = ranks if reqs["rank"] is None else reqs["rank"]
    neg_dtypes = dtypes if reqs["dtype"] is None else reqs["dtype"]
    neg_shapes = shapeFilter[:2] if reqs["shape"] is None else reqs["shape"]
    return {
        "shapeFilter": neg_shapes,
        "rankFilter": neg_ranks,
        "dtypeFilter": neg_dtypes,
    }
2883
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Enumerate the tests to generate for one operator.

    Args:
        opName: key of the operator in TOSA_OP_LIST.
        shapeFilter: optional list of shapes to restrict the tests to. None
            (the default) or an empty list means no shape restriction.
        rankFilter: optional list of ranks, or None for the defaults chosen
            by create_filter_lists().
        dtypeFilter: optional list of dtypes, or None for all operator types.
        testType: "positive" for normal tests, "negative" for ERROR_IF tests.

    Returns:
        list of (opName, testStr, dtype, error_name, shapeList, args) tuples,
        laid out in the argument order of serializeTest().

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as err:
        raise Exception("Cannot find op with name {}".format(opName)) from err

    # A None default avoids a shared mutable default; [None] = no shape restriction.
    if shapeFilter is None:
        shapeFilter = [None]

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]
        logger.debug(
            "genOpTestList: Error=%s, Filters S=%s, R=%s, T=%s",
            error_name,
            cleanShapeFilter,
            cleanRankFilter,
            cleanDtypeFilter,
        )

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Each argument entry is a (string representation,
                    # build function arguments) tuple
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a
        # ERROR_IF statement; stop at the first validator that flags a test.
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2992
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
):
    """Create, build and serialize one test case.

    Args:
        opName: key of the operator in TOSA_OP_LIST.
        testStr: test name; also passed to createSerializer().
        dtype_or_dtypeList: a single dtype shared by all operands, or a list
            holding one dtype per operand.
        error_name: ERROR_IF condition being generated, or None.
        shapeList: list of operand shapes.
        argsDict: argument dictionary for the tensor-values generator and the
            build function; must contain "dg_type".

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
        AssertionError: if operand shapes/dtypes don't match the operand
            count (non-CONCAT ops) or argsDict is not a suitable dict.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    logger.info(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # CONCAT style ops take a variable number of inputs, one per shape
    concat_like = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        repeat = len(shapeList) if concat_like else num_operands
        dtypeList = [dtype_or_dtypeList] * repeat

    if not concat_like:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Extra meta data for the desc.json
    tensMeta = {}

    # Only the argsDict (dictionary based) tvg/build_fcn interface is supported
    assert isinstance(
        argsDict, dict
    ), f"{opName} is not using new tvg/build_fcn interface"
    assert "dg_type" in argsDict

    tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
    if tvgInfo.dataGenDict:
        tensMeta["data_gen"] = tvgInfo.dataGenDict

    result = build_fcn(
        self,
        op,
        tvgInfo.tensorList,
        argsDict,
        validator_fcns=error_if_validators,
        error_name=error_name,
        qinfo=qinfo,
    )

    if not result:
        # The test is not valid
        logger.error(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid; attach compliance meta data (if any) and serialize
    if isinstance(result, TosaTestGen.BuildInfo):
        compliance = result.getComplianceInfo()
        if compliance:
            tensMeta["compliance"] = compliance
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003082
def createDynamicOpLists(self):
    """Expand the convolution *_TEMPLATE entries of TOSA_OP_LIST.

    For every kernel size a shallow copy of the matching template is added,
    named like "conv2d_3x3" or "conv3d_1x2x1". Its "filter", "template" (False)
    and "real_name" fields are set. Afterwards every entry whose "template"
    field is truthy is removed. This is a no-op once "conv2d_TEMPLATE" is gone.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    # Kernel sizes to instantiate; level 8k testing uses the maximum kernel
    if self.args.level8k:
        big = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big], [big, 2]]
        kernels_3d = [[1, big, 1], [2, 2, big]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def add_variant(base_name, kernel):
        # Copy the base_name template and specialise it for one kernel size
        variant = op_list[base_name + "_TEMPLATE"].copy()
        variant["filter"] = kernel
        variant["template"] = False
        variant["real_name"] = base_name
        op_list[base_name + "_" + "x".join(map(str, kernel))] = variant

    for kernel in kernels_2d:
        for base_name in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            add_variant(base_name, kernel)

    for kernel in kernels_3d:
        add_variant("conv3d", kernel)

    # Collect the templates first, then delete them: removing keys from a
    # dictionary while iterating over it is not allowed
    templates = [name for name, entry in op_list.items() if entry.get("template")]
    for name in templates:
        del op_list[name]
3142
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.

    Also lints TOSA_OP_LIST (datastructure linting). Every entry must provide
    a 2-tuple "operands", a 4-tuple "build_fcn", a "types" field and an "op"
    field. A missing "rank" is filled in with DEFAULT_RANK_RANGE.

    Raises:
        Exception: naming the first op with a missing or malformed required
            field; a failed lookup/unpack is chained as the cause.
    """
    for op, entry in self.TOSA_OP_LIST.items():
        # Required fields - unpack to check both presence and tuple length
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
            ) from err

        try:
            _fcn, _tgen, _tvgen, _arggen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(op)
            ) from err

        if "types" not in entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
            )

        if "op" not in entry:
            raise Exception("Op {} is missing the Op field in TOSA_OP_LIST".format(op))

        # Put in default rank range, if missing
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07003184
# Tensor operator list
# 'op': TOSA Op enum value of the operator
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max);
#     if not specified, initOpListDefaults() fills in DEFAULT_RANK_RANGE
# 'build_fcn': tuple of (build function, tensor shape generator,
#     tensor values generator, argument generator)
# 'types': list of datatypes to be tested
# Datatype lists used for the 'types' field of TOSA_OP_LIST entries.
# NOTE: list order sets the order in which per-dtype tests are generated,
# because genOpTestList() iterates an op's types in order.
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]  # floating-point types

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
# floating-types, integer types (excluding INT4) and BOOL
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
TYPE_FI16 = [DType.FP32, DType.INT16]  # FP32 and INT16 only

# 8/16-bit integer types plus floating-types (excludes INT4 and INT32)
TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
    [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
    [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
]

# Inclusive (min, max) rank range that initOpListDefaults() assigns to any
# TOSA_OP_LIST entry without its own 'rank' field
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003244
3245 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003246 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003247 "argmax": {
3248 "op": Op.ARGMAX,
3249 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003250 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003251 "build_fcn": (
3252 build_argmax,
3253 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003254 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003255 TosaArgGen.agAxis,
3256 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003257 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003258 "error_if_validators": (
3259 TosaErrorValidator.evAxisSmallerZero,
3260 TosaErrorValidator.evAxisLargerRank,
3261 TosaErrorValidator.evArgmaxOutputRankMismatch,
3262 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3263 TosaErrorValidator.evWrongRank,
3264 TosaErrorValidator.evWrongInputType,
3265 TosaErrorValidator.evWrongOutputType,
3266 TosaErrorValidator.evWrongInputList,
3267 TosaErrorValidator.evWrongOutputList,
3268 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003269 "data_gen": {
3270 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3271 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003272 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003273 "avg_pool2d": {
3274 "op": Op.AVG_POOL2D,
3275 "operands": (1, 0),
3276 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003277 "build_fcn": (
3278 build_pool2d,
3279 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003280 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003281 TosaArgGen.agPooling,
3282 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003283 "qgen": TosaQuantGen.qgUnary,
Won Jeon2c34b462024-02-06 18:37:00 +00003284 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003285 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003286 "error_if_validators": (
3287 TosaErrorValidator.evKernelSmallerOne,
3288 TosaErrorValidator.evStrideSmallerOne,
3289 TosaErrorValidator.evPadSmallerZero,
3290 TosaErrorValidator.evWrongRank,
3291 TosaErrorValidator.evWrongInputType,
3292 TosaErrorValidator.evWrongOutputType,
3293 TosaErrorValidator.evWrongInputList,
3294 TosaErrorValidator.evWrongOutputList,
3295 TosaErrorValidator.evInputZeroPointNotZero,
3296 TosaErrorValidator.evOutputZeroPointNotZero,
3297 TosaErrorValidator.evPadLargerEqualKernel,
3298 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003299 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00003300 TosaErrorValidator.evWrongAccumulatorType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003301 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003302 "data_gen": {
3303 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3304 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003305 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003306 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003307 "conv2d_TEMPLATE": {
3308 "op": Op.CONV2D,
3309 "operands": (1, 2),
3310 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003311 "build_fcn": (
3312 build_conv2d,
3313 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003314 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003315 TosaArgGen.agConv,
3316 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003317 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003318 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003319 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3320 "error_if_validators": (
3321 TosaErrorValidator.evWrongInputType,
3322 TosaErrorValidator.evWrongOutputType,
3323 TosaErrorValidator.evWrongInputList,
3324 TosaErrorValidator.evWrongOutputList,
3325 TosaErrorValidator.evInputZeroPointNotZero,
3326 TosaErrorValidator.evWeightZeroPointNotZero,
3327 TosaErrorValidator.evPadSmallerZero,
3328 TosaErrorValidator.evStrideSmallerOne,
3329 TosaErrorValidator.evDilationSmallerOne,
3330 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003331 TosaErrorValidator.evConvOutputShapeMismatch,
3332 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003333 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003334 "data_gen": {
3335 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3336 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003337 "template": True,
3338 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003339 # Templated operator. Filled in by createDynamicOpLists
3340 "conv3d_TEMPLATE": {
3341 "op": Op.CONV3D,
3342 "operands": (1, 2),
3343 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003344 "build_fcn": (
3345 build_conv3d,
3346 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003347 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003348 TosaArgGen.agConv,
3349 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003350 "qgen": TosaQuantGen.qgConv,
3351 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003352 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3353 "error_if_validators": (
3354 TosaErrorValidator.evWrongInputType,
3355 TosaErrorValidator.evWrongOutputType,
3356 TosaErrorValidator.evWrongInputList,
3357 TosaErrorValidator.evWrongOutputList,
3358 TosaErrorValidator.evInputZeroPointNotZero,
3359 TosaErrorValidator.evWeightZeroPointNotZero,
3360 TosaErrorValidator.evPadSmallerZero,
3361 TosaErrorValidator.evStrideSmallerOne,
3362 TosaErrorValidator.evDilationSmallerOne,
3363 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003364 TosaErrorValidator.evConvOutputShapeMismatch,
3365 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003366 ),
evacha0147ab1762024-01-29 13:23:23 +00003367 "data_gen": {
3368 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3369 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003370 "template": True,
3371 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003372 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003373 "depthwise_conv2d_TEMPLATE": {
3374 "op": Op.DEPTHWISE_CONV2D,
3375 "operands": (1, 2),
3376 "filter": [1, 1],
3377 "rank": (4, 4),
3378 "build_fcn": (
3379 build_depthwise_conv2d,
3380 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003381 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003382 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003383 ),
3384 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003385 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003386 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3387 "error_if_validators": (
3388 TosaErrorValidator.evWrongInputType,
3389 TosaErrorValidator.evWrongOutputType,
3390 TosaErrorValidator.evWrongInputList,
3391 TosaErrorValidator.evWrongOutputList,
3392 TosaErrorValidator.evInputZeroPointNotZero,
3393 TosaErrorValidator.evWeightZeroPointNotZero,
3394 TosaErrorValidator.evPadSmallerZero,
3395 TosaErrorValidator.evStrideSmallerOne,
3396 TosaErrorValidator.evDilationSmallerOne,
3397 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003398 TosaErrorValidator.evConvOutputShapeMismatch,
3399 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003400 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003401 "data_gen": {
3402 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3403 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003404 "template": True,
3405 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003406 "fully_connected": {
3407 "op": Op.FULLY_CONNECTED,
3408 "operands": (1, 2),
3409 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003410 "build_fcn": (
3411 build_fully_connected,
3412 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003413 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003414 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003415 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003416 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003417 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003418 "error_if_validators": (
3419 TosaErrorValidator.evInputZeroPointNotZero,
3420 TosaErrorValidator.evWeightZeroPointNotZero,
3421 TosaErrorValidator.evWrongRank,
3422 TosaErrorValidator.evWrongInputType,
3423 TosaErrorValidator.evWrongOutputType,
3424 TosaErrorValidator.evWrongInputList,
3425 TosaErrorValidator.evWrongOutputList,
3426 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003427 "data_gen": {
3428 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3429 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003430 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003431 "matmul": {
3432 "op": Op.MATMUL,
3433 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003434 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003435 "build_fcn": (
3436 build_matmul,
3437 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003438 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003439 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003440 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003441 "qgen": TosaQuantGen.qgMatmul,
Won Jeon2c34b462024-02-06 18:37:00 +00003442 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003443 "error_if_validators": (
3444 TosaErrorValidator.evInputZeroPointNotZero,
3445 TosaErrorValidator.evWrongRank,
3446 TosaErrorValidator.evWrongInputType,
3447 TosaErrorValidator.evWrongOutputType,
3448 TosaErrorValidator.evWrongInputList,
3449 TosaErrorValidator.evWrongOutputList,
3450 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003451 "data_gen": {
3452 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003453 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003454 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003455 "max_pool2d": {
3456 "op": Op.MAX_POOL2D,
3457 "operands": (1, 0),
3458 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003459 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003460 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003461 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003462 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003463 TosaArgGen.agPooling,
3464 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003465 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003466 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003467 "error_if_validators": (
3468 TosaErrorValidator.evKernelSmallerOne,
3469 TosaErrorValidator.evStrideSmallerOne,
3470 TosaErrorValidator.evPadSmallerZero,
3471 TosaErrorValidator.evWrongRank,
3472 TosaErrorValidator.evWrongInputType,
3473 TosaErrorValidator.evWrongOutputType,
3474 TosaErrorValidator.evWrongInputList,
3475 TosaErrorValidator.evWrongOutputList,
3476 TosaErrorValidator.evPadLargerEqualKernel,
3477 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003478 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003479 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003480 "data_gen": {
3481 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3482 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003483 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003484 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003485 "transpose_conv2d_TEMPLATE": {
3486 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003487 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003488 "rank": (4, 4),
3489 "build_fcn": (
3490 build_transpose_conv2d,
3491 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003492 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003493 TosaArgGen.agTransposeConv2D,
3494 ),
3495 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003496 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003497 "invalid_test_validators": (
3498 TosaInvalidValidator.ivHeightWidthInvalid,
3499 TosaInvalidValidator.ivNonPositiveOutputShape,
3500 ),
3501 "error_if_validators": (
3502 TosaErrorValidator.evWrongInputType,
3503 TosaErrorValidator.evWrongOutputType,
3504 TosaErrorValidator.evWrongInputList,
3505 TosaErrorValidator.evWrongOutputList,
3506 TosaErrorValidator.evInputZeroPointNotZero,
3507 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003508 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003509 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003510 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003511 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003512 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00003513 "data_gen": {
3514 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3515 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003516 "template": True,
3517 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003518 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003519 "clamp": {
3520 "op": Op.CLAMP,
3521 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003522 "build_fcn": (
3523 build_clamp,
3524 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003525 TosaTensorValuesGen.tvgLazyGenDefault,
3526 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003527 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003528 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003529 "error_if_validators": (
3530 TosaErrorValidator.evMaxSmallerMin,
3531 TosaErrorValidator.evWrongInputType,
3532 TosaErrorValidator.evWrongOutputType,
3533 TosaErrorValidator.evWrongInputList,
3534 TosaErrorValidator.evWrongOutputList,
3535 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003536 "data_gen": {
3537 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3538 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003539 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003540 "sigmoid": {
3541 "op": Op.SIGMOID,
3542 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003543 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003544 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003545 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003546 TosaTensorValuesGen.tvgLazyGenDefault,
3547 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003548 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003549 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003550 "error_if_validators": (
3551 TosaErrorValidator.evWrongInputType,
3552 TosaErrorValidator.evWrongOutputType,
3553 TosaErrorValidator.evWrongInputList,
3554 TosaErrorValidator.evWrongOutputList,
3555 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003556 "data_gen": {
3557 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3558 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003559 },
3560 "tanh": {
3561 "op": Op.TANH,
3562 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003563 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003564 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003565 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003566 TosaTensorValuesGen.tvgLazyGenDefault,
3567 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003568 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003569 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003570 "error_if_validators": (
3571 TosaErrorValidator.evWrongInputType,
3572 TosaErrorValidator.evWrongOutputType,
3573 TosaErrorValidator.evWrongInputList,
3574 TosaErrorValidator.evWrongOutputList,
3575 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003576 "data_gen": {
3577 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3578 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003579 "compliance": {
3580 "abs_error_lower_bound": 0.5,
3581 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003582 },
Won Jeon78155c62023-06-10 00:20:04 +00003583 "erf": {
3584 "op": Op.ERF,
3585 "operands": (1, 0),
3586 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003587 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003588 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003589 TosaTensorValuesGen.tvgLazyGenDefault,
3590 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003591 ),
3592 "types": TYPE_FP,
3593 "error_if_validators": (
3594 TosaErrorValidator.evWrongInputType,
3595 TosaErrorValidator.evWrongOutputType,
3596 TosaErrorValidator.evWrongInputList,
3597 TosaErrorValidator.evWrongOutputList,
3598 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003599 "data_gen": {
3600 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3601 },
3602 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003603 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003604 # Elementwise Binary Operators
3605 "add": {
3606 "op": Op.ADD,
3607 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003608 "build_fcn": (
3609 build_binary_broadcast,
3610 TosaTensorGen.tgBroadcastFuzz,
3611 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003612 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003613 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003614 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003615 "error_if_validators": (
3616 TosaErrorValidator.evRankMismatch,
3617 TosaErrorValidator.evWrongInputType,
3618 TosaErrorValidator.evWrongOutputType,
3619 TosaErrorValidator.evWrongInputList,
3620 TosaErrorValidator.evWrongOutputList,
3621 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003622 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003623 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003624 "data_gen": {
3625 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3626 },
3627 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003628 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003629 "arithmetic_right_shift": {
3630 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3631 "operands": (2, 0),
3632 "build_fcn": (
3633 build_arithmetic_right_shift,
3634 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003635 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003636 TosaArgGen.agArithmeticRightShift,
3637 ),
3638 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003639 "error_if_validators": (
3640 TosaErrorValidator.evRankMismatch,
3641 TosaErrorValidator.evWrongInputType,
3642 TosaErrorValidator.evWrongOutputType,
3643 TosaErrorValidator.evWrongInputList,
3644 TosaErrorValidator.evWrongOutputList,
3645 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003646 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003647 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003648 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003649 "bitwise_and": {
3650 "op": Op.BITWISE_AND,
3651 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003652 "build_fcn": (
3653 build_binary_broadcast,
3654 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003655 TosaTensorValuesGen.tvgLazyGenDefault,
3656 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003657 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003658 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003659 "error_if_validators": (
3660 TosaErrorValidator.evRankMismatch,
3661 TosaErrorValidator.evWrongInputType,
3662 TosaErrorValidator.evWrongOutputType,
3663 TosaErrorValidator.evWrongInputList,
3664 TosaErrorValidator.evWrongOutputList,
3665 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003666 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003667 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003668 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003669 "bitwise_or": {
3670 "op": Op.BITWISE_OR,
3671 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003672 "build_fcn": (
3673 build_binary_broadcast,
3674 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003675 TosaTensorValuesGen.tvgLazyGenDefault,
3676 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003677 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003678 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003679 "error_if_validators": (
3680 TosaErrorValidator.evRankMismatch,
3681 TosaErrorValidator.evWrongInputType,
3682 TosaErrorValidator.evWrongOutputType,
3683 TosaErrorValidator.evWrongInputList,
3684 TosaErrorValidator.evWrongOutputList,
3685 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003686 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003687 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003688 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003689 "bitwise_xor": {
3690 "op": Op.BITWISE_XOR,
3691 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003692 "build_fcn": (
3693 build_binary_broadcast,
3694 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003695 TosaTensorValuesGen.tvgLazyGenDefault,
3696 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003697 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003698 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003699 "error_if_validators": (
3700 TosaErrorValidator.evRankMismatch,
3701 TosaErrorValidator.evWrongInputType,
3702 TosaErrorValidator.evWrongOutputType,
3703 TosaErrorValidator.evWrongInputList,
3704 TosaErrorValidator.evWrongOutputList,
3705 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003706 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003707 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003708 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003709 "intdiv": {
3710 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003711 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003712 "build_fcn": (
3713 build_binary_broadcast,
3714 TosaTensorGen.tgBroadcastFuzz,
3715 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003716 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003717 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003718 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003719 "error_if_validators": (
3720 TosaErrorValidator.evRankMismatch,
3721 TosaErrorValidator.evWrongInputType,
3722 TosaErrorValidator.evWrongOutputType,
3723 TosaErrorValidator.evWrongInputList,
3724 TosaErrorValidator.evWrongOutputList,
3725 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003726 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003727 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003728 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003729 "logical_and": {
3730 "op": Op.LOGICAL_AND,
3731 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003732 "build_fcn": (
3733 build_binary_broadcast,
3734 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003735 TosaTensorValuesGen.tvgLazyGenDefault,
3736 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003737 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003738 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003739 "error_if_validators": (
3740 TosaErrorValidator.evRankMismatch,
3741 TosaErrorValidator.evWrongInputType,
3742 TosaErrorValidator.evWrongOutputType,
3743 TosaErrorValidator.evWrongInputList,
3744 TosaErrorValidator.evWrongOutputList,
3745 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003746 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003747 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003748 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003749 "logical_left_shift": {
3750 "op": Op.LOGICAL_LEFT_SHIFT,
3751 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003752 "build_fcn": (
3753 build_binary_broadcast,
3754 TosaTensorGen.tgBroadcastFuzz,
3755 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003756 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003757 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003758 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003759 "error_if_validators": (
3760 TosaErrorValidator.evRankMismatch,
3761 TosaErrorValidator.evWrongInputType,
3762 TosaErrorValidator.evWrongOutputType,
3763 TosaErrorValidator.evWrongInputList,
3764 TosaErrorValidator.evWrongOutputList,
3765 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003766 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003767 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003768 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003769 "logical_right_shift": {
3770 "op": Op.LOGICAL_RIGHT_SHIFT,
3771 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003772 "build_fcn": (
3773 build_binary_broadcast,
3774 TosaTensorGen.tgBroadcastFuzz,
3775 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003776 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003777 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003778 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003779 "error_if_validators": (
3780 TosaErrorValidator.evRankMismatch,
3781 TosaErrorValidator.evWrongInputType,
3782 TosaErrorValidator.evWrongOutputType,
3783 TosaErrorValidator.evWrongInputList,
3784 TosaErrorValidator.evWrongOutputList,
3785 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003786 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003787 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003788 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003789 "logical_or": {
3790 "op": Op.LOGICAL_OR,
3791 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003792 "build_fcn": (
3793 build_binary_broadcast,
3794 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003795 TosaTensorValuesGen.tvgLazyGenDefault,
3796 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003797 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003798 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003799 "error_if_validators": (
3800 TosaErrorValidator.evRankMismatch,
3801 TosaErrorValidator.evWrongInputType,
3802 TosaErrorValidator.evWrongOutputType,
3803 TosaErrorValidator.evWrongInputList,
3804 TosaErrorValidator.evWrongOutputList,
3805 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003806 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003807 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003808 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003809 "logical_xor": {
3810 "op": Op.LOGICAL_XOR,
3811 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003812 "build_fcn": (
3813 build_binary_broadcast,
3814 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003815 TosaTensorValuesGen.tvgLazyGenDefault,
3816 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003817 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003818 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003819 "error_if_validators": (
3820 TosaErrorValidator.evRankMismatch,
3821 TosaErrorValidator.evWrongInputType,
3822 TosaErrorValidator.evWrongOutputType,
3823 TosaErrorValidator.evWrongInputList,
3824 TosaErrorValidator.evWrongOutputList,
3825 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003826 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003827 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003828 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003829 "maximum": {
3830 "op": Op.MAXIMUM,
3831 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003832 "build_fcn": (
3833 build_binary_broadcast,
3834 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003835 TosaTensorValuesGen.tvgLazyGenDefault,
3836 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003837 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003838 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003839 "error_if_validators": (
3840 TosaErrorValidator.evRankMismatch,
3841 TosaErrorValidator.evWrongInputType,
3842 TosaErrorValidator.evWrongOutputType,
3843 TosaErrorValidator.evWrongInputList,
3844 TosaErrorValidator.evWrongOutputList,
3845 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003846 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003847 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003848 "data_gen": {
3849 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3850 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003851 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003852 "minimum": {
3853 "op": Op.MINIMUM,
3854 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003855 "build_fcn": (
3856 build_binary_broadcast,
3857 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003858 TosaTensorValuesGen.tvgLazyGenDefault,
3859 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003860 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003861 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003862 "error_if_validators": (
3863 TosaErrorValidator.evRankMismatch,
3864 TosaErrorValidator.evWrongInputType,
3865 TosaErrorValidator.evWrongOutputType,
3866 TosaErrorValidator.evWrongInputList,
3867 TosaErrorValidator.evWrongOutputList,
3868 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003869 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003870 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003871 "data_gen": {
3872 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3873 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003874 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003875 "mul": {
3876 "op": Op.MUL,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003877 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003878 "build_fcn": (
3879 build_mul,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003880 TosaTensorGen.tgMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003881 TosaTensorValuesGen.tvgMul,
3882 TosaArgGen.agMul,
3883 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003884 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003885 "error_if_validators": (
3886 TosaErrorValidator.evWrongInputType,
3887 TosaErrorValidator.evWrongOutputType,
3888 TosaErrorValidator.evWrongInputList,
3889 TosaErrorValidator.evWrongOutputList,
3890 TosaErrorValidator.evRankMismatch,
3891 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003892 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003893 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003894 "data_gen": {
3895 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3896 },
3897 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003898 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003899 "pow": {
3900 "op": Op.POW,
3901 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003902 "build_fcn": (
3903 build_binary_broadcast,
3904 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003905 TosaTensorValuesGen.tvgPow,
3906 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003907 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003908 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003909 "error_if_validators": (
3910 TosaErrorValidator.evRankMismatch,
3911 TosaErrorValidator.evWrongInputType,
3912 TosaErrorValidator.evWrongOutputType,
3913 TosaErrorValidator.evWrongInputList,
3914 TosaErrorValidator.evWrongOutputList,
3915 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003916 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003917 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003918 "data_gen": {
3919 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3920 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003921 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003922 "sub": {
3923 "op": Op.SUB,
3924 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003925 "build_fcn": (
3926 build_binary_broadcast,
3927 TosaTensorGen.tgBroadcastFuzz,
3928 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003929 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003930 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003931 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003932 "error_if_validators": (
3933 TosaErrorValidator.evRankMismatch,
3934 TosaErrorValidator.evWrongInputType,
3935 TosaErrorValidator.evWrongOutputType,
3936 TosaErrorValidator.evWrongInputList,
3937 TosaErrorValidator.evWrongOutputList,
3938 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003939 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003940 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003941 "data_gen": {
3942 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3943 },
3944 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003945 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003946 "table": {
3947 "op": Op.TABLE,
3948 # Use the automatic generation functions to create the input array
3949 # but create the table tensor in the build function, as it may be
3950 # a different type from the input
3951 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003952 "build_fcn": (
3953 build_table,
3954 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003955 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003956 TosaArgGen.agTable,
3957 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003958 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003959 "error_if_validators": (
3960 TosaErrorValidator.evWrongInputType,
3961 TosaErrorValidator.evWrongOutputType,
3962 TosaErrorValidator.evWrongInputList,
3963 TosaErrorValidator.evWrongOutputList,
3964 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003965 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003966 # Elementwise Unary operators
3967 "abs": {
3968 "op": Op.ABS,
3969 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003970 "build_fcn": (
3971 build_unary,
3972 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003973 TosaTensorValuesGen.tvgLazyGenDefault,
3974 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003975 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003976 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003977 "error_if_validators": (
3978 TosaErrorValidator.evWrongInputType,
3979 TosaErrorValidator.evWrongOutputType,
3980 TosaErrorValidator.evWrongInputList,
3981 TosaErrorValidator.evWrongOutputList,
3982 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003983 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00003984 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003985 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003986 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003987 "bitwise_not": {
3988 "op": Op.BITWISE_NOT,
3989 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003990 "build_fcn": (
3991 build_unary,
3992 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003993 TosaTensorValuesGen.tvgLazyGenDefault,
3994 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003995 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003996 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003997 "error_if_validators": (
3998 TosaErrorValidator.evWrongInputType,
3999 TosaErrorValidator.evWrongOutputType,
4000 TosaErrorValidator.evWrongInputList,
4001 TosaErrorValidator.evWrongOutputList,
4002 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004003 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004004 "ceil": {
4005 "op": Op.CEIL,
4006 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004007 "build_fcn": (
4008 build_unary,
4009 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004010 TosaTensorValuesGen.tvgLazyGenDefault,
4011 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004012 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004013 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004014 "error_if_validators": (
4015 TosaErrorValidator.evWrongInputType,
4016 TosaErrorValidator.evWrongOutputType,
4017 TosaErrorValidator.evWrongInputList,
4018 TosaErrorValidator.evWrongOutputList,
4019 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004020 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004021 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004022 },
4023 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004024 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004025 "clz": {
4026 "op": Op.CLZ,
4027 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004028 "build_fcn": (
4029 build_unary,
4030 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004031 TosaTensorValuesGen.tvgLazyGenDefault,
4032 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004033 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004034 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004035 "error_if_validators": (
4036 TosaErrorValidator.evWrongInputType,
4037 TosaErrorValidator.evWrongOutputType,
4038 TosaErrorValidator.evWrongInputList,
4039 TosaErrorValidator.evWrongOutputList,
4040 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004041 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004042 "cos": {
4043 "op": Op.COS,
4044 "operands": (1, 0),
4045 "build_fcn": (
4046 build_unary,
4047 TosaTensorGen.tgBasic,
4048 TosaTensorValuesGen.tvgLazyGenDefault,
4049 TosaArgGen.agNone,
4050 ),
4051 "types": TYPE_FP,
4052 "error_if_validators": (
4053 TosaErrorValidator.evWrongInputType,
4054 TosaErrorValidator.evWrongOutputType,
4055 TosaErrorValidator.evWrongInputList,
4056 TosaErrorValidator.evWrongOutputList,
4057 ),
4058 "data_gen": {
4059 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4060 },
4061 "compliance": {"abs_error_normal_divisor": 2},
4062 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004063 "exp": {
4064 "op": Op.EXP,
4065 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004066 "build_fcn": (
4067 build_unary,
4068 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004069 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004070 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004071 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004072 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004073 "error_if_validators": (
4074 TosaErrorValidator.evWrongInputType,
4075 TosaErrorValidator.evWrongOutputType,
4076 TosaErrorValidator.evWrongInputList,
4077 TosaErrorValidator.evWrongOutputList,
4078 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004079 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004080 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004081 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004082 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004083 "floor": {
4084 "op": Op.FLOOR,
4085 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004086 "build_fcn": (
4087 build_unary,
4088 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004089 TosaTensorValuesGen.tvgLazyGenDefault,
4090 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004091 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004092 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004093 "error_if_validators": (
4094 TosaErrorValidator.evWrongInputType,
4095 TosaErrorValidator.evWrongOutputType,
4096 TosaErrorValidator.evWrongInputList,
4097 TosaErrorValidator.evWrongOutputList,
4098 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004099 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004100 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004101 },
4102 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004103 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004104 "log": {
4105 "op": Op.LOG,
4106 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004107 "build_fcn": (
4108 build_unary,
4109 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004110 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004111 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004112 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004113 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004114 "error_if_validators": (
4115 TosaErrorValidator.evWrongInputType,
4116 TosaErrorValidator.evWrongOutputType,
4117 TosaErrorValidator.evWrongInputList,
4118 TosaErrorValidator.evWrongOutputList,
4119 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004120 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004121 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004122 },
4123 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004124 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004125 "logical_not": {
4126 "op": Op.LOGICAL_NOT,
4127 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004128 "build_fcn": (
4129 build_unary,
4130 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004131 TosaTensorValuesGen.tvgLazyGenDefault,
4132 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004133 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004134 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004135 "error_if_validators": (
4136 TosaErrorValidator.evWrongInputType,
4137 TosaErrorValidator.evWrongOutputType,
4138 TosaErrorValidator.evWrongInputList,
4139 TosaErrorValidator.evWrongOutputList,
4140 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004141 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004142 "negate": {
4143 "op": Op.NEGATE,
4144 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004145 "build_fcn": (
4146 build_unary,
4147 TosaTensorGen.tgBasic,
4148 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004149 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004150 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004151 "qgen": TosaQuantGen.qgUnary,
4152 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004153 "error_if_validators": (
4154 TosaErrorValidator.evInputZeroPointNotZero,
4155 TosaErrorValidator.evOutputZeroPointNotZero,
4156 TosaErrorValidator.evWrongInputType,
4157 TosaErrorValidator.evWrongOutputType,
4158 TosaErrorValidator.evWrongInputList,
4159 TosaErrorValidator.evWrongOutputList,
4160 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004161 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004162 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004163 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004164 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004165 "reciprocal": {
4166 "op": Op.RECIPROCAL,
4167 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004168 "build_fcn": (
4169 build_unary,
4170 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004171 TosaTensorValuesGen.tvgLazyGenDefault,
4172 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004173 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004174 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004175 "error_if_validators": (
4176 TosaErrorValidator.evWrongInputType,
4177 TosaErrorValidator.evWrongOutputType,
4178 TosaErrorValidator.evWrongInputList,
4179 TosaErrorValidator.evWrongOutputList,
4180 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004181 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004182 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004183 },
4184 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004185 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004186 "rsqrt": {
4187 "op": Op.RSQRT,
4188 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004189 "build_fcn": (
4190 build_unary,
4191 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004192 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004193 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004194 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004195 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004196 "error_if_validators": (
4197 TosaErrorValidator.evWrongInputType,
4198 TosaErrorValidator.evWrongOutputType,
4199 TosaErrorValidator.evWrongInputList,
4200 TosaErrorValidator.evWrongOutputList,
4201 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004202 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004203 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004204 },
4205 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004206 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004207 "sin": {
4208 "op": Op.SIN,
4209 "operands": (1, 0),
4210 "build_fcn": (
4211 build_unary,
4212 TosaTensorGen.tgBasic,
4213 TosaTensorValuesGen.tvgLazyGenDefault,
4214 TosaArgGen.agNone,
4215 ),
4216 "types": TYPE_FP,
4217 "error_if_validators": (
4218 TosaErrorValidator.evWrongInputType,
4219 TosaErrorValidator.evWrongOutputType,
4220 TosaErrorValidator.evWrongInputList,
4221 TosaErrorValidator.evWrongOutputList,
4222 ),
4223 "data_gen": {
4224 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4225 },
4226 "compliance": {"abs_error_normal_divisor": 2},
4227 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004228 # Elementwise Ternary operators
4229 "select": {
4230 "op": Op.SELECT,
4231 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004232 "build_fcn": (
4233 build_select,
4234 TosaTensorGen.tgBroadcastFuzz,
4235 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004236 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004237 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004238 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004239 "error_if_validators": (
4240 TosaErrorValidator.evRankMismatch,
4241 TosaErrorValidator.evWrongInputType,
4242 TosaErrorValidator.evWrongOutputType,
4243 TosaErrorValidator.evWrongInputList,
4244 TosaErrorValidator.evWrongOutputList,
4245 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004246 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004247 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004248 "data_gen": {
4249 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4250 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004251 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004252 # Comparison operators
4253 "equal": {
4254 "op": Op.EQUAL,
4255 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004256 "build_fcn": (
4257 build_comparison,
4258 TosaTensorGen.tgBroadcastFuzz,
4259 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004260 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004261 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004262 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004263 "error_if_validators": (
4264 TosaErrorValidator.evRankMismatch,
4265 TosaErrorValidator.evWrongInputType,
4266 TosaErrorValidator.evWrongOutputType,
4267 TosaErrorValidator.evWrongInputList,
4268 TosaErrorValidator.evWrongOutputList,
4269 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004270 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004271 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004272 "data_gen": {
4273 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4274 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004275 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004276 "greater_equal": {
4277 "op": Op.GREATER_EQUAL,
4278 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004279 "build_fcn": (
4280 build_comparison,
4281 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004282 TosaTensorValuesGen.tvgLazyGenDefault,
4283 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004284 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004285 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004286 "error_if_validators": (
4287 TosaErrorValidator.evRankMismatch,
4288 TosaErrorValidator.evWrongInputType,
4289 TosaErrorValidator.evWrongOutputType,
4290 TosaErrorValidator.evWrongInputList,
4291 TosaErrorValidator.evWrongOutputList,
4292 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004293 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004294 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004295 "data_gen": {
4296 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4297 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004298 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004299 "greater": {
4300 "op": Op.GREATER,
4301 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004302 "build_fcn": (
4303 build_comparison,
4304 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004305 TosaTensorValuesGen.tvgLazyGenDefault,
4306 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004307 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004308 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004309 "error_if_validators": (
4310 TosaErrorValidator.evRankMismatch,
4311 TosaErrorValidator.evWrongInputType,
4312 TosaErrorValidator.evWrongOutputType,
4313 TosaErrorValidator.evWrongInputList,
4314 TosaErrorValidator.evWrongOutputList,
4315 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004316 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004317 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004318 "data_gen": {
4319 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4320 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004321 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004322 # Reduction operators
4323 "reduce_all": {
4324 "op": Op.REDUCE_ALL,
4325 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004326 "build_fcn": (
4327 build_reduce,
4328 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004329 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004330 TosaArgGen.agAxis,
4331 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004332 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004333 "error_if_validators": (
4334 TosaErrorValidator.evAxisLargerRank,
4335 TosaErrorValidator.evAxisSmallerZero,
4336 TosaErrorValidator.evShapeOfAxisNotOne,
4337 TosaErrorValidator.evWrongInputType,
4338 TosaErrorValidator.evWrongOutputType,
4339 TosaErrorValidator.evWrongRank,
4340 TosaErrorValidator.evWrongInputList,
4341 TosaErrorValidator.evWrongOutputList,
4342 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004343 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004344 "reduce_any": {
4345 "op": Op.REDUCE_ANY,
4346 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004347 "build_fcn": (
4348 build_reduce,
4349 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004350 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004351 TosaArgGen.agAxis,
4352 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004353 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004354 "error_if_validators": (
4355 TosaErrorValidator.evAxisLargerRank,
4356 TosaErrorValidator.evAxisSmallerZero,
4357 TosaErrorValidator.evShapeOfAxisNotOne,
4358 TosaErrorValidator.evWrongInputType,
4359 TosaErrorValidator.evWrongOutputType,
4360 TosaErrorValidator.evWrongRank,
4361 TosaErrorValidator.evWrongInputList,
4362 TosaErrorValidator.evWrongOutputList,
4363 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004364 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004365 "reduce_max": {
4366 "op": Op.REDUCE_MAX,
4367 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004368 "build_fcn": (
4369 build_reduce,
4370 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004371 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004372 TosaArgGen.agAxis,
4373 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004374 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004375 "error_if_validators": (
4376 TosaErrorValidator.evAxisLargerRank,
4377 TosaErrorValidator.evAxisSmallerZero,
4378 TosaErrorValidator.evShapeOfAxisNotOne,
4379 TosaErrorValidator.evWrongInputType,
4380 TosaErrorValidator.evWrongOutputType,
4381 TosaErrorValidator.evWrongRank,
4382 TosaErrorValidator.evWrongInputList,
4383 TosaErrorValidator.evWrongOutputList,
4384 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004385 "data_gen": {
4386 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4387 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004388 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004389 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004390 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004391 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004392 "build_fcn": (
4393 build_reduce,
4394 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004395 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004396 TosaArgGen.agAxis,
4397 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004398 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004399 "error_if_validators": (
4400 TosaErrorValidator.evAxisLargerRank,
4401 TosaErrorValidator.evAxisSmallerZero,
4402 TosaErrorValidator.evShapeOfAxisNotOne,
4403 TosaErrorValidator.evWrongInputType,
4404 TosaErrorValidator.evWrongOutputType,
4405 TosaErrorValidator.evWrongRank,
4406 TosaErrorValidator.evWrongInputList,
4407 TosaErrorValidator.evWrongOutputList,
4408 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004409 "data_gen": {
4410 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4411 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004412 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004413 "reduce_product": {
4414 "op": Op.REDUCE_PRODUCT,
4415 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004416 "build_fcn": (
4417 build_reduce,
4418 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004419 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004420 TosaArgGen.agAxis,
4421 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004422 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004423 "error_if_validators": (
4424 TosaErrorValidator.evAxisLargerRank,
4425 TosaErrorValidator.evAxisSmallerZero,
4426 TosaErrorValidator.evShapeOfAxisNotOne,
4427 TosaErrorValidator.evWrongInputType,
4428 TosaErrorValidator.evWrongOutputType,
4429 TosaErrorValidator.evWrongRank,
4430 TosaErrorValidator.evWrongInputList,
4431 TosaErrorValidator.evWrongOutputList,
4432 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004433 "data_gen": {
4434 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4435 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004436 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004437 "reduce_sum": {
4438 "op": Op.REDUCE_SUM,
4439 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004440 "build_fcn": (
4441 build_reduce,
4442 TosaTensorGen.tgBasic,
4443 TosaTensorValuesGen.tvgReduceSum,
4444 TosaArgGen.agAxis,
4445 ),
James Ward24dbc422022-10-19 12:20:31 +01004446 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004447 "error_if_validators": (
4448 TosaErrorValidator.evAxisLargerRank,
4449 TosaErrorValidator.evAxisSmallerZero,
4450 TosaErrorValidator.evShapeOfAxisNotOne,
4451 TosaErrorValidator.evWrongInputType,
4452 TosaErrorValidator.evWrongOutputType,
4453 TosaErrorValidator.evWrongRank,
4454 TosaErrorValidator.evWrongInputList,
4455 TosaErrorValidator.evWrongOutputList,
4456 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004457 "data_gen": {
4458 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4459 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004460 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004461 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004462 "concat": {
4463 "op": Op.CONCAT,
4464 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004465 "build_fcn": (
4466 build_concat,
4467 TosaTensorGen.tgConcat,
4468 TosaTensorValuesGen.tvgConcat,
4469 TosaArgGen.agAxis,
4470 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004471 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004472 "error_if_validators": (
4473 TosaErrorValidator.evAxisLargerRank,
4474 TosaErrorValidator.evAxisSmallerZero,
4475 TosaErrorValidator.evConcatInputRankMismatch,
4476 TosaErrorValidator.evConcatShapeSumMismatch,
4477 TosaErrorValidator.evConcatInputDimMismatch,
4478 TosaErrorValidator.evWrongInputType,
4479 TosaErrorValidator.evWrongOutputType,
4480 TosaErrorValidator.evWrongOutputList,
4481 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004482 "data_gen": {
4483 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4484 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004485 },
4486 "pad": {
4487 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004488 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004489 "build_fcn": (
4490 build_pad,
4491 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004492 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004493 TosaArgGen.agPad,
4494 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004495 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004496 "error_if_validators": (
4497 TosaErrorValidator.evWrongInputType,
4498 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004499 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004500 TosaErrorValidator.evWrongOutputType,
4501 TosaErrorValidator.evWrongInputList,
4502 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004503 TosaErrorValidator.evRankMismatch,
4504 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004505 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004506 "data_gen": {
4507 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4508 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004509 },
Won Jeona21b2e82023-08-10 10:33:01 +00004510 "dim": {
4511 "op": Op.DIM,
4512 "operands": (1, 0),
4513 "build_fcn": (
4514 build_dim,
4515 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004516 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004517 TosaArgGen.agAxis,
4518 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004519 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Won Jeona21b2e82023-08-10 10:33:01 +00004520 "error_if_validators": (
4521 TosaErrorValidator.evAxisLargerRank,
4522 TosaErrorValidator.evAxisSmallerZero,
4523 TosaErrorValidator.evWrongInputType,
4524 TosaErrorValidator.evWrongInputList,
4525 TosaErrorValidator.evWrongOutputList,
4526 TosaErrorValidator.evWrongRank,
4527 ),
4528 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004529 "reshape": {
4530 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004531 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004532 "build_fcn": (
4533 build_reshape,
4534 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004535 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004536 TosaArgGen.agReshape,
4537 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004538 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004539 "error_if_validators": (
4540 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4541 TosaErrorValidator.evWrongInputType,
4542 TosaErrorValidator.evWrongOutputType,
4543 TosaErrorValidator.evWrongInputList,
4544 TosaErrorValidator.evWrongOutputList,
4545 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004546 "data_gen": {
4547 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4548 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004549 },
4550 "reverse": {
4551 "op": Op.REVERSE,
4552 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004553 "build_fcn": (
4554 build_reverse,
4555 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004556 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004557 TosaArgGen.agAxis,
4558 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004559 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004560 "error_if_validators": (
4561 TosaErrorValidator.evAxisSmallerZero,
4562 TosaErrorValidator.evAxisLargerRank,
4563 TosaErrorValidator.evWrongInputType,
4564 TosaErrorValidator.evWrongOutputType,
4565 TosaErrorValidator.evWrongInputList,
4566 TosaErrorValidator.evWrongOutputList,
4567 ),
evacha0198477222024-01-26 12:25:32 +00004568 "data_gen": {
4569 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4570 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004571 },
4572 "slice": {
4573 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004574 "operands": (3, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004575 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004576 "build_fcn": (
4577 build_slice,
4578 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004579 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004580 TosaArgGen.agSlice,
4581 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004582 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004583 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004584 # TODO Turn off these error categories for now as the reference
4585 # model cannot allocate memory space for empty tensor. We probably
4586 # can report an accurate error message at the right place during
4587 # execution.
4588 # TosaErrorValidator.evStartSmallerZero,
4589 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004590 TosaErrorValidator.evStartSizeOutsideBounds,
4591 TosaErrorValidator.evSizeOutputShapeMismatch,
4592 TosaErrorValidator.evInputSizeStartLengthMismatch,
4593 TosaErrorValidator.evWrongRank,
4594 TosaErrorValidator.evWrongInputType,
4595 TosaErrorValidator.evWrongOutputType,
4596 TosaErrorValidator.evWrongInputList,
4597 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004598 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004599 ),
evacha017f7d4252024-01-24 12:08:09 +00004600 "data_gen": {
4601 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4602 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004603 },
4604 "tile": {
4605 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004606 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004607 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004608 "build_fcn": (
4609 build_tile,
4610 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004611 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004612 TosaArgGen.agTile,
4613 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004614 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004615 "error_if_validators": (
4616 TosaErrorValidator.evWrongInputType,
4617 TosaErrorValidator.evWrongOutputType,
4618 TosaErrorValidator.evWrongInputList,
4619 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004620 TosaErrorValidator.evRankMismatch,
4621 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004622 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004623 "data_gen": {
4624 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4625 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004626 },
4627 "transpose": {
4628 "op": Op.TRANSPOSE,
4629 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004630 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004631 "build_fcn": (
4632 build_transpose,
4633 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004634 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004635 TosaArgGen.agTranspose,
4636 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004637 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004638 "error_if_validators": (
4639 TosaErrorValidator.evIndexOutsideBounds,
4640 TosaErrorValidator.evIndexUsedTwice,
4641 TosaErrorValidator.evWrongInputType,
4642 TosaErrorValidator.evWrongOutputType,
4643 TosaErrorValidator.evWrongInputList,
4644 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004645 TosaErrorValidator.evWrongRank,
4646 TosaErrorValidator.evRankMismatch,
4647 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004648 ),
evacha0198477222024-01-26 12:25:32 +00004649 "data_gen": {
4650 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4651 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004652 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004653 # Data nodes
4654 "const": {
4655 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004656 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004657 "build_fcn": (
4658 build_const,
4659 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004660 TosaTensorValuesGen.tvgLazyGenDefault,
4661 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004662 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004663 "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
evacha0198477222024-01-26 12:25:32 +00004664 "data_gen": {
4665 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4666 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004667 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004668 "identity": {
4669 "op": Op.IDENTITY,
4670 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004671 "build_fcn": (
4672 build_unary,
4673 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004674 TosaTensorValuesGen.tvgLazyGenDefault,
4675 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004676 ),
evacha011adff832024-03-06 17:33:44 +00004677 "types": TYPE_FIB + [DType.INT4, DType.INT48],
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004678 "data_gen": {
4679 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4680 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004681 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004682 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004683 "gather": {
4684 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004685 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004686 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004687 "build_fcn": (
4688 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004689 TosaTensorGen.tgGather,
4690 TosaTensorValuesGen.tvgGather,
4691 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004692 ),
James Ward24dbc422022-10-19 12:20:31 +01004693 "types": (
4694 DType.INT8,
4695 DType.INT16,
4696 DType.INT32,
4697 DType.FP16,
4698 DType.BF16,
4699 DType.FP32,
Won Jeon2c34b462024-02-06 18:37:00 +00004700 DType.FP8E4M3,
4701 DType.FP8E5M2,
James Ward24dbc422022-10-19 12:20:31 +01004702 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004703 "error_if_validators": (
4704 TosaErrorValidator.evWrongInputType,
4705 TosaErrorValidator.evWrongOutputType,
4706 TosaErrorValidator.evWrongInputList,
4707 TosaErrorValidator.evWrongOutputList,
4708 TosaErrorValidator.evWrongRank,
4709 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004710 "data_gen": {
4711 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4712 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004713 },
4714 "scatter": {
4715 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004716 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004717 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004718 "build_fcn": (
4719 build_scatter,
4720 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004721 TosaTensorValuesGen.tvgScatter,
4722 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004723 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004724 "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004725 "error_if_validators": (
4726 TosaErrorValidator.evWrongInputType,
4727 TosaErrorValidator.evWrongOutputType,
4728 TosaErrorValidator.evWrongInputList,
4729 TosaErrorValidator.evWrongOutputList,
4730 TosaErrorValidator.evWrongRank,
4731 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004732 "data_gen": {
4733 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4734 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004735 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004736 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004737 "resize": {
4738 "op": Op.RESIZE,
Tai Lyc5c2a7e2024-02-22 23:26:28 +00004739 "operands": (4, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004740 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004741 "build_fcn": (
4742 build_resize,
4743 TosaTensorGen.tgNHWC,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004744 TosaTensorValuesGen.tvgResize,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004745 TosaArgGen.agResize,
4746 ),
James Ward24dbc422022-10-19 12:20:31 +01004747 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004748 "invalid_test_validators": (
4749 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004750 ),
4751 "error_if_validators": (
4752 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004753 TosaErrorValidator.evScaleSmallerEqualZero,
4754 TosaErrorValidator.evScaleNLargerMax,
4755 TosaErrorValidator.evScaleDLargerMax,
4756 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004757 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004758 TosaErrorValidator.evBorderSmallerMin,
4759 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004760 TosaErrorValidator.evWrongInputType,
4761 TosaErrorValidator.evWrongOutputType,
4762 TosaErrorValidator.evWrongRank,
4763 TosaErrorValidator.evWrongInputList,
4764 TosaErrorValidator.evWrongOutputList,
4765 TosaErrorValidator.evBatchMismatch,
4766 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004767 TosaErrorValidator.evResizeOutputShapeMismatch,
4768 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004769 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004770 "data_gen": {
4771 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4772 },
4773 "compliance": {"relative": 0.006},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004774 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004775 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004776 "cast": {
4777 "op": Op.CAST,
4778 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004779 "build_fcn": (
4780 build_cast,
4781 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004782 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004783 TosaArgGen.agCast,
4784 ),
James Ward8b390432022-08-12 20:48:56 +01004785 "types": (
4786 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004787 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004788 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004789 DType.INT8,
4790 DType.INT16,
4791 DType.INT32,
4792 DType.BOOL,
Won Jeon2c34b462024-02-06 18:37:00 +00004793 DType.FP8E4M3,
4794 DType.FP8E5M2,
James Ward8b390432022-08-12 20:48:56 +01004795 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004796 "error_if_validators": (
4797 TosaErrorValidator.evWrongInputType,
4798 TosaErrorValidator.evWrongOutputType,
4799 TosaErrorValidator.evWrongInputList,
4800 TosaErrorValidator.evWrongOutputList,
4801 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004802 "data_gen": {
4803 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4804 },
4805 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004806 },
4807 "rescale": {
4808 "op": Op.RESCALE,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004809 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004810 "build_fcn": (
4811 build_rescale,
4812 TosaTensorGen.tgBasic,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004813 TosaTensorValuesGen.tvgRescale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004814 TosaArgGen.agRescale,
4815 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004816 "types": [
4817 DType.UINT8,
4818 DType.INT8,
4819 DType.INT16,
4820 DType.INT32,
4821 DType.INT48,
4822 DType.UINT16,
4823 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004824 "error_if_validators": (
4825 TosaErrorValidator.evInputZeroPointNotZero,
4826 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004827 TosaErrorValidator.evU16InputZeroPointNotValid,
4828 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004829 TosaErrorValidator.evScaleTrue,
4830 TosaErrorValidator.evScaleNotTrue,
4831 TosaErrorValidator.evWrongInputType,
4832 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004833 TosaErrorValidator.evWrongInputList,
4834 TosaErrorValidator.evWrongOutputList,
4835 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004836 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004837 # Custom
4838 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004839 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004840 # Two variants of cond_if, one that generates one of two constant tensors (no
4841 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4842 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004843 "cond_if_const": {
4844 "op": Op.COND_IF,
4845 "operands": (0, 2),
4846 "build_fcn": (
4847 build_cond_if_const,
4848 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004849 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004850 TosaArgGen.agCondIf,
4851 ),
4852 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004853 "error_if_validators": (
4854 TosaErrorValidator.evOutputListThenGraphMismatch,
4855 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004856 TosaErrorValidator.evCondIfCondNotMatchingBool,
4857 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004858 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004859 },
4860 "cond_if_binary": {
4861 "op": Op.COND_IF,
4862 "operands": (2, 0),
4863 "build_fcn": (
4864 build_cond_if_binary,
4865 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004866 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004867 TosaArgGen.agCondIf,
4868 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004869 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004870 "error_if_validators": (
4871 TosaErrorValidator.evInputListThenGraphMismatch,
4872 TosaErrorValidator.evInputListElseGraphMismatch,
4873 TosaErrorValidator.evOutputListThenGraphMismatch,
4874 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004875 TosaErrorValidator.evCondIfCondNotMatchingBool,
4876 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004877 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004878 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004879 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004880 "while_loop": {
4881 "op": Op.WHILE_LOOP,
4882 "operands": (0, 1),
4883 "build_fcn": (
4884 build_while_loop,
4885 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004886 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004887 TosaArgGen.agWhileLoop,
4888 ),
4889 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004890 "error_if_validators": (
4891 TosaErrorValidator.evInputListOutputListMismatch,
4892 TosaErrorValidator.evInputListCondGraphMismatch,
4893 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4894 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4895 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004896 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004897 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004898 },
Luke Hutton57287132023-02-06 14:54:18 +00004899 "fft2d": {
4900 "op": Op.FFT2D,
4901 "operands": (2, 0),
4902 "rank": (3, 3),
4903 "build_fcn": (
4904 build_fft2d,
4905 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004906 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004907 TosaArgGen.agFFT2d,
4908 ),
4909 "types": [DType.FP32],
4910 "error_if_validators": (
4911 TosaErrorValidator.evWrongInputType,
4912 TosaErrorValidator.evWrongOutputType,
4913 TosaErrorValidator.evWrongInputList,
4914 TosaErrorValidator.evWrongOutputList,
4915 TosaErrorValidator.evWrongRank,
4916 TosaErrorValidator.evBatchMismatch,
4917 TosaErrorValidator.evKernelNotPowerOfTwo,
4918 TosaErrorValidator.evFFTInputShapeMismatch,
4919 TosaErrorValidator.evFFTOutputShapeMismatch,
4920 ),
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004921 "data_gen": {
4922 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4923 },
Luke Hutton57287132023-02-06 14:54:18 +00004924 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004925 "rfft2d": {
4926 "op": Op.RFFT2D,
4927 "operands": (1, 0),
4928 "rank": (3, 3),
4929 "build_fcn": (
4930 build_rfft2d,
4931 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004932 TosaTensorValuesGen.tvgLazyGenDefault,
4933 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00004934 ),
4935 "types": [DType.FP32],
4936 "error_if_validators": (
4937 TosaErrorValidator.evWrongInputType,
4938 TosaErrorValidator.evWrongOutputType,
4939 TosaErrorValidator.evWrongInputList,
4940 TosaErrorValidator.evWrongOutputList,
4941 TosaErrorValidator.evWrongRank,
4942 TosaErrorValidator.evBatchMismatch,
4943 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004944 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004945 ),
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004946 "data_gen": {
4947 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4948 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004949 },
Won Jeon74342e52024-01-09 00:34:40 +00004950 # Shape
4951 "add_shape": {
4952 "op": Op.ADD_SHAPE,
4953 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004954 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004955 "build_fcn": (
4956 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004957 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004958 TosaTensorValuesGen.tvgAddSub,
4959 TosaArgGen.agNone,
4960 ),
4961 "types": [DType.SHAPE],
4962 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4963 },
4964 "sub_shape": {
4965 "op": Op.SUB_SHAPE,
4966 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004967 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004968 "build_fcn": (
4969 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004970 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004971 TosaTensorValuesGen.tvgAddSub,
4972 TosaArgGen.agNone,
4973 ),
4974 "types": [DType.SHAPE],
4975 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4976 },
4977 "mul_shape": {
4978 "op": Op.MUL_SHAPE,
4979 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004980 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004981 "build_fcn": (
4982 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004983 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004984 TosaTensorValuesGen.tvgMul,
4985 TosaArgGen.agNone,
4986 ),
4987 "types": [DType.SHAPE],
4988 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4989 },
4990 "div_shape": {
4991 "op": Op.DIV_SHAPE,
4992 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004993 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004994 "build_fcn": (
4995 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004996 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004997 TosaTensorValuesGen.tvgIntDiv,
4998 TosaArgGen.agNone,
4999 ),
5000 "types": [DType.SHAPE],
5001 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5002 },
5003 "concat_shape": {
5004 "op": Op.CONCAT_SHAPE,
5005 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005006 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005007 "build_fcn": (
5008 build_concat,
5009 TosaTensorGen.tgConcat,
5010 TosaTensorValuesGen.tvgConcat,
5011 TosaArgGen.agNone,
5012 ),
5013 "types": [DType.SHAPE],
5014 "error_if_validators": (),
5015 },
5016 "const_shape": {
5017 "op": Op.CONST_SHAPE,
5018 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005019 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005020 "build_fcn": (
5021 build_const,
5022 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00005023 TosaTensorValuesGen.tvgLazyGenDefault,
5024 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00005025 ),
5026 "types": [DType.SHAPE],
5027 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005028 }
5029
Kevin Cheng550ccc52021-03-03 11:21:43 -08005030
Eric Kunzee5e26762020-10-13 16:11:07 -07005031class OutputShaper:
5032 # Methods in this class compute the expected output shape and datatype
5033 # for common classes of operations
def __init__(self):
    """Intentionally empty: this constructor sets no attributes."""
5036
5037 # These methods pass the computed output shape and datatype to
5038 # ser.addOutput() and return its result
5039 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Register the output tensor of a broadcasting element-wise binary op.

    Each output dimension is taken from ``a``. The exception is when ``a``
    has size 1 there and no ERROR_IF case is requested; then ``b``'s size is
    used (broadcast).

    Args:
        ser: serializer; the output is registered via ``ser.addOutput``.
        rng: random generator used for ERROR_IF fuzzing.
        a, b: input tensors. Their ranks must match unless error_name is
            RankMismatch, and their dtypes must always match.
        error_name: optional ErrorIf case to inject into the output.

    Returns:
        The result of ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    may_broadcast = error_name is None
    out_shape = [
        b.shape[idx] if (dim == 1 and may_broadcast) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # NOTE: drawn on every call (even when unused), so the random stream
    # advances identically whatever error_name is.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]
    # Any candidate except the correct output dtype (the input dtype).
    wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005072
5073 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Register the output of a binary op whose inputs must match exactly.

    ``a`` and ``b`` must have the same rank, the same dtype and identical
    dimensions. The output copies ``a``'s shape and dtype.

    Returns:
        The result of ``ser.addOutput(shape, a.dtype)``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    for dim_a, dim_b in zip(a.shape, b.shape):
        assert dim_a == dim_b

    # A fresh list, so the caller's shape object is never aliased.
    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005084
5085 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Register the output tensor of an element-wise unary op.

    The output shape always equals ``a.shape``. The output dtype equals
    ``a.dtype``, unless error_name is WrongOutputType; then a different
    dtype is picked at random.

    Returns:
        The result of ``ser.addOutput(a.shape, dtype)``.
    """
    output_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    return ser.addOutput(a.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005103
5104 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Register the output tensor of a SELECT (cond ? a : b) op.

    Output dimensions follow ``cond``. Where ``cond`` has size 1 and no
    ERROR_IF case is requested, the largest of the three input sizes is used
    (broadcast).

    Args:
        ser: serializer; the output is registered via ``ser.addOutput``.
        rng: random generator used for ERROR_IF fuzzing.
        cond, a, b: input tensors. Ranks must match unless error_name is
            RankMismatch; ``a`` and ``b`` must share a dtype.
        error_name: optional ErrorIf case to inject into the output.

    Returns:
        The result of ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for idx, cond_dim in enumerate(cond.shape):
        if cond_dim == 1 and error_name is None:
            cond_dim = max(cond_dim, a.shape[idx], b.shape[idx])
        out_shape.append(cond_dim)

    # Always drawn, so the random stream advances on every call.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    output_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(out_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005137
5138 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Register the BOOL output tensor of an element-wise comparison op.

    Each output dimension comes from ``a``. The exception is when ``a`` has
    size 1 there and ``b`` has a dimension at that index; then ``b``'s size
    is used. Unlike binaryBroadcastOp, this broadcast is not disabled for
    ERROR_IF cases.

    Returns:
        The result of ``ser.addOutput(shape, dtype)``. The dtype is BOOL, or a
        random non-BOOL dtype when error_name is WrongOutputType.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = [
        b.shape[idx] if (a_dim == 1 and len(b.shape) > idx) else a_dim
        for idx, a_dim in enumerate(a.shape)
    ]

    # Always drawn, so the random stream advances on every call.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.BOOL)

    # None of these is BOOL, so every entry is a wrong output dtype.
    non_bool_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput(out_shape, rng.choice(non_bool_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005171
5172 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Register the output tensor of a REDUCE_* op along ``axis``.

    Normally the output is ``a.shape`` with ``axis`` collapsed to size 1.
    For the axis-related ERROR_IF cases, the shape is left unreduced. For
    ShapeOfAxisNotOne, a size-1 axis is bumped to a random size in [2, 10).

    Returns:
        The result of ``ser.addOutput(shape, dtype)``.
    """
    out_shape = a.shape.copy()
    axis_error_cases = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in axis_error_cases:
        out_shape[axis] = 1
    elif error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    output_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(out_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005200
5201 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Register the INT32 output tensor of ARGMAX along ``axis``.

    The output is ``a.shape`` with ``axis`` removed, except for the
    axis-out-of-range ERROR_IF cases. ArgmaxOutputRankMismatch either drops
    the leading dimension (when more than one remains) or appends a size-1
    dimension. ArgmaxOutputShapeMismatch grows every dimension by a random
    amount in [1, 10).

    Returns:
        The result of ``ser.addOutput(shape, dtype)``. The dtype is INT32, or
        a random non-INT32 dtype when error_name is WrongOutputType.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape.pop(axis)

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            out_shape.pop(0)
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.INT32)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - {DType.INT32})
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005236
5237 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Register the output tensor of a CONV2D op.

    Layouts used for indexing below:
        IFM:    NHWC
        Filter: OHWI
        OFM:    NHWC

    Args:
        ser: serializer; the output is registered via ``ser.addOutput``.
        rng: random generator used only for ERROR_IF fuzzing.
        ifm: input feature-map tensor.
        filter: weight tensor.
        accum_dtype: output dtype used when the input dtype is valid.
        strides: (stride_y, stride_x).
        padding: four pad values, as a pair for H followed by a pair for W.
        dilations: (dilation_y, dilation_x).
        error_name: optional ErrorIf case to inject.

    Returns:
        The result of ``ser.addOutput(ofm_shape, out_dtype)``.
    """

    def _out_dim(in_size, pad_before, pad_after, kernel, dilation, stride):
        # Standard dilated-convolution output-size formula.
        return (
            in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        ) // stride + 1

    h = _out_dim(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    w = _out_dim(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow by whole multiples of the stride so the non-integer output
        # ERROR_IF case is not triggered as well.
        if change in (1, 3):
            h += rng.choice(choices) * strides[0]
        if change in (2, 3):
            w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    # For an invalid input dtype, fall back to a plausible output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            # Both FP16 and FP32 are excluded for FP16 input (presumably
            # because either is an acceptable output).
            excludes = [DType.FP16, DType.FP32]
        elif ifm.dtype in (DType.FP8E4M3, DType.FP8E5M2):
            excludes = [DType.FP16]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005290
5291 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Register the output tensor of a CONV3D op.

    Layouts used for indexing below:
        IFM:    NDHWC
        Filter: ODHWI
        OFM:    NDHWC

    Args:
        ser: serializer; the output is registered via ``ser.addOutput``.
        rng: random generator used only for ERROR_IF fuzzing.
        ifm: input feature-map tensor.
        filter: weight tensor.
        accum_dtype: output dtype used when the input dtype is valid.
        strides: (stride_d, stride_y, stride_x).
        padding: six pad values, as pairs for D, H and W in that order.
        dilations: (dilation_d, dilation_y, dilation_x).
        error_name: optional ErrorIf case to inject.

    Returns:
        The result of ``ser.addOutput(ofm_shape, out_dtype)``.
    """
    # IFM: NDHWC
    # Filter: ODHWI
    # OFM: NDHWC

    d = (
        ifm.shape[1]
        - 1
        + padding[0]
        + padding[1]
        - (filter.shape[1] - 1) * dilations[0]
    ) // strides[0] + 1

    h = (
        ifm.shape[2]
        - 1
        + padding[2]
        + padding[3]
        - (filter.shape[2] - 1) * dilations[1]
    ) // strides[1] + 1

    w = (
        ifm.shape[3]
        - 1
        + padding[4]
        + padding[5]
        - (filter.shape[3] - 1) * dilations[2]
    ) // strides[2] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in [1, 4]:
            d = d + (rng.choice(choices) * strides[0])
        if change in [2, 4]:
            h = h + (rng.choice(choices) * strides[1])
        if change in [3, 4]:
            w = w + (rng.choice(choices) * strides[2])

    ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        elif ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
            # Kept consistent with conv2dOp, which treats FP16 as the valid
            # output for FP8 inputs. FP16 must therefore never be chosen as
            # the "wrong" dtype, whatever accum_dtype was passed in.
            excludes = [DType.FP16]
        else:
            excludes = [out_dtype]
        wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(ofm_shape, out_dtype)
5352
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for DEPTHWISE_CONV2D.

    Layouts: IFM is NHWC, the filter is HWCM and the OFM is NHW(C*M).

    For each spatial axis (0 = height, 1 = width) the output size is
    ``(in - 1 + pad_a + pad_b - (k - 1) * dilation) // stride + 1``, where
    ``padding[0:2]`` pads the height and ``padding[2:4]`` pads the width.

    Error cases:
      * ``ErrorIf.ConvOutputShapeMismatch``: height and/or width grow by a
        random multiple (1-3) of the corresponding stride.
      * ``ErrorIf.WrongInputType``: the output dtype is INT32.
      * ``ErrorIf.WrongOutputType``: a dtype from ``gtu.usableDTypes`` that
        excludes the expected one is drawn from ``rng``.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    spatial = []
    for axis in (0, 1):
        extent = ifm.shape[axis + 1] - 1 + padding[2 * axis] + padding[2 * axis + 1]
        extent -= (filter.shape[axis] - 1) * dilations[axis]
        spatial.append(extent // strides[axis] + 1)
    h, w = spatial

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow in whole strides so the non-integer output error is not hit
        if change in (1, 3):
            h += rng.choice(choices) * strides[0]
        if change in (2, 3):
            w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

    # When the input dtype is the error under test, use a plausible output dtype
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 inputs both FP16 and FP32 are excluded (presumably both are
        # acceptable accumulator types); otherwise only the expected dtype is.
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005403
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor for a 2D pooling op (input layout NHWC).

    With positive strides and non-negative padding, each spatial size is
    ``(in + pad_a + pad_b - kernel) // stride + 1`` (``pad[0:2]`` for the
    height, ``pad[2:4]`` for the width). Otherwise the test is invalid
    anyway and both spatial sizes are set to 1.

    Error cases:
      * ``ErrorIf.PoolingOutputShapeMismatch``: height and/or width grow by a
        random multiple (1-3) of the corresponding stride.
      * ``ErrorIf.WrongOutputType``: a dtype other than ``ifm.dtype`` is
        drawn from a fixed candidate list.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # Bad stride/pad: the test is invalid anyway, so use a 1x1 output
        h, w = 1, 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow in whole strides so the non-integer output error is not hit
        if change in (1, 3):
            h += rng.choice(choices) * stride[0]
        if change in (2, 3):
            w += rng.choice(choices) * stride[1]

    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    wrong_dtypes = list(set(candidates) - {ifm.dtype})
    return ser.addOutput(ofm_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005443
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Create the output tensor for FULLY_CONNECTED.

    Shapes: ``input`` is [N, IC], ``filter`` is [OC, IC]; the output is
    [N, OC]. ``rng`` and ``error_name`` are not used here.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    # accum_dtype is validated in arg_gen (and deliberately invalidated for
    # ERROR_IF tests), so it is used verbatim as the output dtype.
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005456
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the output tensor for MATMUL.

    Shapes: ``a`` is [N, H, C], ``b`` is [N, C, W]; the output is [N, H, W].

    Output dtype selection:
      * ``ErrorIf.WrongOutputType``: a random entry from a hand-picked tuple
        of incorrect dtypes for ``a.dtype``. An input dtype with no
        hand-picked tuple falls back to any usable dtype other than
        ``accum_dtype`` rather than failing on an unbound local.
      * ``ErrorIf.WrongInputType``: INT32, a potentially correct output
        dtype, because the input dtype is the error under test.
      * Otherwise ``accum_dtype``, which arg_gen has already validated.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            # Omits INT32
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype == DType.INT16:
            # Omits INT48
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype in (DType.FP8E4M3, DType.FP8E5M2):
            # Omits FP16
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            # Omits FP32, FP16 and BF16
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        else:
            # No hand-picked list for this input dtype: any usable dtype other
            # than the expected accumulator type is an incorrect output type.
            incorrect_types = tuple(gtu.usableDTypes(excludes=[accum_dtype]))
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005522
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Create the output tensor for CONCAT.

    The output shape is the first input's shape with the ``axis`` dimension
    summed over all inputs. The sum is skipped (first input's shape is used
    as-is) for ``ConcatInputRankMismatch``, ``AxisLargerRank`` and
    ``AxisSmallerZero``, where it cannot be computed.

    Further error cases:
      * ``ErrorIf.ConcatShapeSumMismatch``: ``rng.integers(5, 10)`` is added
        to the ``axis`` dimension.
      * ``ErrorIf.WrongOutputType``: a dtype other than the first input's is
        drawn from a fixed candidate set.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    first = inputs[0]
    others = inputs[1:]

    output_shape = first.shape.copy()
    cannot_sum = error_name in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if not cannot_sum:
        for other in others:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        outputDType = rng.choice(list(candidates - {first.dtype}))
    else:
        outputDType = first.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005558
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor for PAD.

    Each output dimension is ``padding[i][0] + padding[i][1] + a.shape[i]``.

    Error cases:
      * ``ErrorIf.PadOutputShapeMismatch``: one randomly chosen dimension is
        increased by 1 or 2.
      * ``ErrorIf.RankMismatch``: the shape is replaced by
        ``gtu.get_rank_mismatch_shape``.
      * ``ErrorIf.WrongOutputType``: a dtype other than ``a.dtype`` is drawn
        from a fixed candidate list.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    output_shape = a.shape.copy()
    for dim_idx, dim_size in enumerate(output_shape):
        before, after = padding[dim_idx][0], padding[dim_idx][1]
        output_shape[dim_idx] = before + after + dim_size

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(output_shape)))
        output_shape[bad_dim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005589
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for DIM.

    The output always has shape [1]. Its dtype is ``DType.SHAPE`` unless
    testing ``ErrorIf.WrongOutputType``, in which case a non-SHAPE dtype is
    drawn from a fixed list. ``a`` and ``axis`` are not used here.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([1], DType.SHAPE)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # None of the candidates is SHAPE, so nothing needs excluding
    return ser.addOutput([1], rng.choice(list(set(candidates))))
5610
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor for RESHAPE.

    The output shape is a copy of ``shape``.

    Error cases:
      * ``ErrorIf.TensorSizeInputOutputMismatch``: every dimension grows by
        ``rng.integers(1, 10)``.
      * ``ErrorIf.WrongOutputType``: a dtype other than ``a.dtype`` is drawn
        from a fixed candidate list.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dim_idx, dim_size in enumerate(output_shape):
            output_shape[dim_idx] = dim_size + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005635
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Create the output tensor for SLICE.

    The output shape is a copy of ``size`` (``start`` is not used here).

    Error cases:
      * ``ErrorIf.WrongOutputType``: a dtype other than ``input.dtype`` is
        drawn from a fixed candidate list.
      * ``ErrorIf.SizeOutputShapeMismatch``: each dimension <= 2 grows by 1
        or 2; larger ones change by -2, -1, 1 or 2.
      * ``ErrorIf.InputSizeStartLengthMismatch``: the input shape is used.
      * ``ErrorIf.RankMismatch``: shape replaced via
        ``gtu.get_rank_mismatch_shape``.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(list(set(candidates) - {input.dtype}))
    else:
        outputDType = input.dtype

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for dim_idx, dim_size in enumerate(output_shape):
            # Small dims only grow (presumably so they never shrink to <= 0)
            deltas = [1, 2] if dim_size <= 2 else [-2, -1, 1, 2]
            output_shape[dim_idx] = dim_size + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005669
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor for TILE.

    Output dimension i is ``a.shape[i] * multiples[i]``; ``multiples`` must
    have one entry per input dimension.

    Error cases:
      * ``ErrorIf.RankMismatch``: shape replaced via
        ``gtu.get_rank_mismatch_shape``.
      * ``ErrorIf.WrongOutputType``: a dtype other than ``a.dtype`` is drawn
        from a fixed candidate list.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for dim_idx, dim_size in enumerate(a.shape):
        output_shape[dim_idx] = dim_size * multiples[dim_idx]

    if error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005698
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor for TRANSPOSE.

    Output dimension i is ``a.shape[perms[i]]``. For the invalid-permutation
    errors (``IndexOutsideBounds``, ``IndexUsedTwice``) the permutation is not
    applied and the input shape is kept.

    Further error cases:
      * ``ErrorIf.TensorSizeInputOutputMismatch``: every dimension grows by
        ``rng.integers(1, 10)``.
      * ``ErrorIf.RankMismatch``: shape replaced via
        ``gtu.get_rank_mismatch_shape``.
      * ``ErrorIf.WrongOutputType``: a dtype other than ``a.dtype`` is drawn
        from a fixed candidate list.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    output_shape = a.shape.copy()
    assert len(perms) == len(output_shape)

    if error_name not in (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice):
        for out_axis in range(len(output_shape)):
            output_shape[out_axis] = a.shape[perms[out_axis]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for out_axis, dim_size in enumerate(output_shape):
            output_shape[out_axis] = dim_size + rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005731
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor for GATHER.

    ``values`` must be rank 3 (checked unless testing ``ErrorIf.WrongRank``)
    and ``indices`` rank 2, sharing dimension 0. The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``.

    For ``ErrorIf.WrongOutputType`` a dtype other than ``values.dtype`` is
    drawn from a fixed candidate list.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, values.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - {values.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Kevin Cheng77d0f762020-11-24 10:26:32 -08005757
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor for SCATTER.

    ``values_in`` and ``input`` are rank 3 and ``indices`` rank 2 (the
    ``values_in`` rank check is skipped for ``ErrorIf.WrongRank``). The output
    takes ``values_in.shape`` itself (the same object, not a copy).

    For ``ErrorIf.WrongOutputType`` a dtype other than ``values_in.dtype`` is
    drawn from a fixed candidate list.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    output_shape = values_in.shape

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, values_in.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    wrong_dtypes = list(set(candidates) - {values_in.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005788
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor for TABLE.

    The output has the input's shape. Its dtype is INT32 for INT16 inputs and
    INT8 otherwise. The input must be INT8 or INT16 unless testing
    ``ErrorIf.WrongInputType``. For ``ErrorIf.WrongOutputType`` a different
    dtype is drawn from a fixed list.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        output_dtype = DType.INT32
    else:
        output_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        candidates.remove(output_dtype)
        output_dtype = rng.choice(candidates)

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005808
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor for RESIZE.

    ``scale`` holds [y_n, y_d, x_n, x_d]. The output height/width are
    ``((in - 1) * n - offset + border) // d + 1`` using ``input.shape[1]``
    and ``input.shape[2]``. ``mode`` and ``input_dtype`` are not used here.

    For any ``error_name`` the sizes are clamped to at least 1 and, unless
    testing ``ErrorIf.MaxDimExceeded``, below ``gtu.MAX_RESIZE_DIMENSION``.
    Further error cases alter the shape (``ResizeOutputShapeMismatch``,
    ``WrongRank``, ``BatchMismatch``, ``ChannelMismatch``).

    Returns:
        Whatever ``serializer.addOutput`` returns for the new output tensor.
    """
    y_num, y_den, x_num, x_den = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp to >= 1 so the size computation below never divides by zero
        y_num, y_den, x_num, x_den = (max(v, 1) for v in (y_num, y_den, x_num, x_den))

    oh = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    ow = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # Scale/offset/border may have been altered for the ERROR_IF, which
        # can produce an invalid size; keep the output tensor valid.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            dim_cap = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, dim_cap), min(ow, dim_cap)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Move by whole denominators so the non-integer error case is not hit
        if change in (1, 3):
            if oh + y_den >= gtu.MAX_RESIZE_DIMENSION:
                oh -= y_den
                assert oh > 0  # Should have been caught in agResize
            else:
                oh += y_den
        if change in (2, 3):
            if ow + x_den >= gtu.MAX_RESIZE_DIMENSION:
                ow -= x_den
                assert ow > 0  # Should have been caught in agResize
            else:
                ow += x_den

    if error_name == ErrorIf.WrongRank:
        # NOTE(review): the last dim reuses input.shape[0], not shape[3];
        # possibly deliberate because the input rank is wrong - confirm.
        output_dims = [input.shape[0], oh, ow, input.shape[0]]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [input.shape[0] + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [input.shape[0], oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [input.shape[0], oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005887
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create an output tensor with ``val``'s shape and dtype ``out_dtype``.

    ``rng`` and ``error_name`` are not used here.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    shape = val.shape
    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005891
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Create the output tensor for TRANSPOSE_CONV2D.

    The output uses ``output_shape`` as given. Note: for
    ``ErrorIf.ConvOutputShapeMismatch`` that list is modified IN PLACE
    (dims 1 and/or 2 grow by 1-3), so the caller's list sees the change.

    Output dtype: INT32 for ``ErrorIf.WrongInputType``, otherwise
    ``accum_dtype``; for ``ErrorIf.WrongOutputType`` a dtype from
    ``gtu.usableDTypes`` that excludes the expected one is drawn.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        for dim_idx, triggers in ((1, (1, 3)), (2, (2, 3))):
            if change in triggers:
                output_shape[dim_idx] = output_shape[dim_idx] + rng.choice(choices)

    # When the input dtype is the error under test, use a plausible output dtype
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 inputs both FP16 and FP32 are excluded (presumably both are
        # acceptable accumulator types); otherwise only the expected dtype is.
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005917
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Create the two output tensors for FFT2D.

    Both inputs must share a dtype, and a shape unless testing
    ``ErrorIf.FFTInputShapeMismatch``; the shape must be rank 3 unless
    testing ``ErrorIf.WrongRank``. Both outputs get the same shape and dtype
    (by default the input's).

    Error cases:
      * ``ErrorIf.WrongOutputType``: a usable dtype other than FP32.
      * ``ErrorIf.BatchMismatch``: dim 0 grows by ``rng.integers(1, 10)``.
      * ``ErrorIf.FFTOutputShapeMismatch``: dim 1 or 2 grows likewise.

    Returns:
        A list with the two values returned by ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    output_shape = ifm1.shape.copy()
    output_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        target_dim = rng.choice([1, 2])
        output_shape[target_dim] += rng.integers(1, 10)

    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
5948
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Create the two output tensors for RFFT2D.

    The input must be rank 3 unless testing ``ErrorIf.WrongRank``. Both
    outputs keep every input dim except the last, which becomes
    ``last // 2 + 1``; the dtype defaults to the input's.

    Error cases:
      * ``ErrorIf.WrongOutputType``: a usable dtype other than FP32.
      * ``ErrorIf.BatchMismatch``: dim 0 grows by ``rng.integers(1, 10)``.
      * ``ErrorIf.FFTOutputShapeMismatch``: dim 1 or 2 grows likewise.

    Returns:
        A list with the two values returned by ``serializer.addOutput``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    output_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    output_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        target_dim = rng.choice([1, 2])
        output_shape[target_dim] += rng.integers(1, 10)

    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
Won Jeon74342e52024-01-09 00:34:40 +00005973
@staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for the shape-addition op.

    The output shape is a copy of ``a.shape``. ``a`` and ``b`` must share a
    dtype and, unless testing ``ErrorIf.RankMismatch``, a rank. The output
    dtype is ``DType.SHAPE``.

    Error cases:
      * ``ErrorIf.DimensionMismatch``: one random dimension is increased by 1.
      * ``ErrorIf.WrongOutputType``: a dtype is drawn from
        ``gtu.get_wrong_output_type(a.dtype)`` instead of ``DType.SHAPE``.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = list(a.shape)

    # Note: drawn on every call (not only for DimensionMismatch), so rng
    # always advances here.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        outputDType = rng.choice(gtu.get_wrong_output_type(a.dtype))
    else:
        outputDType = DType.SHAPE
    return ser.addOutput(shape, outputDType)