blob: b5e71ac74d93e1706632c5b5ebea77745165463d [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +010016from generator.tosa_utils import DTYPE_ATTRIBUTES
Luke Huttona4e48ca2023-02-22 11:53:48 +000017from generator.tosa_utils import get_rank_mismatch_shape
Jeremy Johnson05c711e2022-12-12 18:00:41 +000018from generator.tosa_utils import get_wrong_output_type
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010019from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010020from generator.tosa_utils import usableDTypes
James Ward24dbc422022-10-19 12:20:31 +010021from generator.tosa_utils import vect_f32_to_bf16
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
25
Eric Kunzee5e26762020-10-13 16:11:07 -070026class TosaTestGen:
    # Maximum rank of tensor supported by test generator.
    # This currently matches the 8K level defined in the specification.
    TOSA_TENSOR_MAX_RANK = 6
    # Further 8K level limits used when generating test arguments; the names
    # suggest the maximum scale, kernel size and stride values -- confirm
    # against the level table in the TOSA specification.
    TOSA_8K_LEVEL_MAX_SCALE = 64
    TOSA_8K_LEVEL_MAX_KERNEL = 8192
    TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010033
Eric Kunzee5e26762020-10-13 16:11:07 -070034 def __init__(self, args):
35 self.args = args
36 self.basePath = args.output_dir
37 self.random_seed = args.random_seed
38 self.ser = None
39 self.rng = np.random.default_rng(self.random_seed)
40 self.createDynamicOpLists()
41 self.initOpListDefaults()
42 self.quantGen = TosaQuantGen()
43 # Force makeShape to do a specific starting shape
44 self.targetted_shape = None
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010045 # Work out floating point range
46 self.random_fp_low = min(args.tensor_fp_value_range)
47 self.random_fp_high = max(args.tensor_fp_value_range)
Eric Kunzee5e26762020-10-13 16:11:07 -070048
49 def createSerializer(self, opName, testPath):
50 self.testPath = os.path.join(opName, testPath)
51
52 fullPath = os.path.join(self.basePath, self.testPath)
53 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +010054 # Embed const data in the flatbuffer
55 constMode = ts.ConstMode.EMBED
56 if self.args.dump_consts:
57 constMode = ts.ConstMode.EMBED_DUMP
58 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070059
60 def getSerializer(self):
61 return self.ser
62
63 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -080064 with open(
65 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
66 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -070067 fd.write(self.ser.serialize())
68
Kevin Cheng550ccc52021-03-03 11:21:43 -080069 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
70 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -070071
Matthew Haddon74567092021-07-16 15:38:20 +010072 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000073 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +010074 seed = self.random_seed + 1
75 self.rng = np.random.default_rng(seed)
76
Eric Kunzee5e26762020-10-13 16:11:07 -070077 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -070078 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -070079 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -070080 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -070081 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -070082 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070083 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +010084 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
85 elif dtype == DType.UINT8:
86 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070087 elif dtype == DType.INT16:
88 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010089 elif dtype == DType.UINT16:
90 return np.int32(self.rng.integers(low=0, high=65536, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070091 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -080092 return np.int32(
93 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
94 )
Eric Kunzee5e26762020-10-13 16:11:07 -070095 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -080096 return np.int64(
97 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
98 )
James Ward8b390432022-08-12 20:48:56 +010099 elif dtype == DType.FP16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100100 return np.float16(
101 self.rng.uniform(
102 low=self.random_fp_low, high=self.random_fp_high, size=shape
103 )
104 )
James Ward24dbc422022-10-19 12:20:31 +0100105 elif dtype == DType.BF16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100106 f32_tensor = np.float32(
107 self.rng.uniform(
108 low=self.random_fp_low, high=self.random_fp_high, size=shape
109 )
110 )
James Ward24dbc422022-10-19 12:20:31 +0100111 # Floor the last 16 bits of each f32 value
112 return np.float32(vect_f32_to_bf16(f32_tensor))
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100113 elif dtype == DType.FP32:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100114 return np.float32(
115 self.rng.uniform(
116 low=self.random_fp_low, high=self.random_fp_high, size=shape
117 )
118 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700119 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800120 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700121
Kevin Cheng989cb052021-04-28 16:29:44 -0700122 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700123 placeholders = []
124
Kevin Cheng989cb052021-04-28 16:29:44 -0700125 assert len(shape_list) == len(dtype_list)
126
127 for idx, shape in enumerate(shape_list):
128 arr = self.getRandTensor(shape, dtype_list[idx])
129 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700130
131 return placeholders
132
Kevin Cheng989cb052021-04-28 16:29:44 -0700133 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700134 consts = []
135
Kevin Cheng989cb052021-04-28 16:29:44 -0700136 assert len(shape_list) == len(dtype_list)
137
138 for idx, shape in enumerate(shape_list):
139 arr = self.getRandTensor(shape, dtype_list[idx])
140 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700141
142 return consts
143
144 def makeShape(self, rank):
145 if self.targetted_shape:
146 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800147 return np.int32(
148 self.rng.integers(
149 low=self.args.tensor_shape_range[0],
150 high=self.args.tensor_shape_range[1],
151 size=rank,
152 )
153 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700154
155 def setTargetShape(self, shape):
156 self.targetted_shape = shape
157
158 def randInt(self, low=0, high=256):
159 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
160
161 def getRandNumberDType(self, dtype):
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100162 if dtype == DType.FP32:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100163 return np.float32(
164 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
165 )
James Ward8b390432022-08-12 20:48:56 +0100166 elif dtype == DType.FP16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100167 return np.float16(
168 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
169 )
James Ward24dbc422022-10-19 12:20:31 +0100170 elif dtype == DType.BF16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100171 rand_f32 = np.float32(
172 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
173 )
James Ward24dbc422022-10-19 12:20:31 +0100174 return vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700175 elif dtype == DType.BOOL:
176 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -0700177 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -0700178 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -0700179 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700180 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100181 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -0700182 elif dtype == DType.INT16:
183 low, high = (-32768, 32768)
184 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800185 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -0700186 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800187 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -0700188 # Special size
189 return np.int64(self.rng.integers(low, high, size=1))[0]
190 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800191 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700192
193 return np.int32(self.rng.integers(low, high, size=1))[0]
194
195 def shapeStr(self, shape):
196
197 sStr = []
198 # Convert to strings
199 for i in shape:
200 sStr.append(str(i))
201
Kevin Cheng550ccc52021-03-03 11:21:43 -0800202 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700203
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100204 def typeStr(self, dtype):
205 if isinstance(dtype, list) or isinstance(dtype, tuple):
206 assert len(dtype) >= 2
207 strs = [self.typeStr(t) for t in dtype]
208 # Limit types to the first 2 as the 3rd is the accumulator
209 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700210 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100211 if dtype in DTYPE_ATTRIBUTES:
212 return DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700213 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100214 raise Exception(
215 "Unknown dtype, cannot convert to string: {}".format(dtype)
216 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700217
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100218 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100219 """Get the datatype width for data types"""
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100220 if dtype in DTYPE_ATTRIBUTES:
221 return DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700222 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100223 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700224
Luke Hutton57287132023-02-06 14:54:18 +0000225 def constrictBatchSize(self, shape):
226 # Limit the batch size unless an explicit target shape set
227 if self.args.max_batch_size and not self.args.target_shapes:
228 shape[0] = min(shape[0], self.args.max_batch_size)
229 return shape
230
James Ward30124a82023-02-02 14:56:33 +0000231 def makeDimension(self):
232 return self.randInt(
233 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
234 )
235
    # Operator build functions
    # Each build_* method adds the operator under test to the serializer and
    # returns its result tensor (or None when ERROR_IF validation fails).
    # Their arguments come from the argument generators, which return a list of
    # tuples (stringDescriptor, [build_fcn_arg_list]), where the string
    # descriptor is used to generate the test name and the build_fcn_arg_list
    # is expanded and passed to the operator test build function.
241
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100242 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
243 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
244
Matthew Haddon848efb42021-09-09 12:30:53 +0100245 # build_placeholder returns an int, ABS/other ops does not
246 if isinstance(op, int):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000247 self.ser.addOperator(op, a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100248 return result_tens
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000249 elif op["op"] == Op.IDENTITY:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000250 self.ser.addOperator(op["op"], a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100251 return result_tens
252
253 # Ensure new output type has correct qinfo
254 if error_name == ErrorIf.WrongOutputType:
255 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000256 qinfo = [
257 TosaQuantGen.getZeroPoint(self, a.dtype),
258 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
259 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100260
261 # Invalidate Input/Output list for error if checks.
262 input_list = [a.name]
263 output_list = [result_tens.name]
264 pCount, cCount = op["operands"]
265 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000266 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
267 self, error_name, input_list, output_list
268 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100269
Les Bell729b0352021-11-24 10:28:21 +0000270 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100271 self.ser,
272 validator_fcns,
273 error_name,
274 op=op,
275 input_dtype=a.dtype,
276 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000277 qinfo=qinfo,
Luke Hutton261b7b62023-01-10 14:50:31 +0000278 result_tensors=[result_tens],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100279 input_list=input_list,
280 output_list=output_list,
281 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000282 ):
283 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100284
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000285 attr = None
286 if op["op"] == Op.NEGATE:
287 attr = ts.TosaSerializerAttribute()
288 attr.NegateAttribute(qinfo[0], qinfo[1])
289
290 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700291 return result_tens
292
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100293 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000294 result_tens = OutputShaper.binaryBroadcastOp(
295 self.ser, self.rng, a, b, error_name
296 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100297
298 # Invalidate Input/Output list for error if checks.
299 input_list = [a.name, b.name]
300 output_list = [result_tens.name]
301 pCount, cCount = op["operands"]
302 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000303 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
304 self, error_name, input_list, output_list
305 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100306
Les Bell729b0352021-11-24 10:28:21 +0000307 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100308 self.ser,
309 validator_fcns,
310 error_name,
311 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000312 input1=a,
313 input2=b,
314 input_dtype=a.dtype,
315 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000316 result_tensors=[result_tens],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100317 input_list=input_list,
318 output_list=output_list,
319 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000320 ):
321 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100322
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000323 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -0700324 return result_tens
325
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100326 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700327 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000328 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700329 return result_tens
330
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000331 def build_arithmetic_right_shift(
332 self, op, a, b, round, validator_fcns=None, error_name=None
333 ):
334 result_tens = OutputShaper.binaryBroadcastOp(
335 self.ser, self.rng, a, b, error_name
336 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100337
338 # Invalidate Input/Output list for error if checks.
339 input_list = [a.name, b.name]
340 output_list = [result_tens.name]
341 pCount, cCount = op["operands"]
342 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000343 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
344 self, error_name, input_list, output_list
345 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100346
Les Bell729b0352021-11-24 10:28:21 +0000347 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100348 self.ser,
349 validator_fcns,
350 error_name,
351 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000352 input1=a,
353 input2=b,
354 input_dtype=a.dtype,
355 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000356 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100357 input_list=input_list,
358 output_list=output_list,
359 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000360 ):
361 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800362
363 attr = ts.TosaSerializerAttribute()
364 attr.ArithmeticRightShiftAttribute(round)
365
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000366 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800367 return result_tens
368
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100369 def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000370 result_tens = OutputShaper.binaryBroadcastOp(
371 self.ser, self.rng, a, b, error_name
372 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700373
374 # Special for multiply:
375 # Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100376 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Eric Kunzee5e26762020-10-13 16:11:07 -0700377 result_tens.setDtype(DType.INT32)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100378 if error_name == ErrorIf.WrongOutputType:
379 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
380 outputDType = self.rng.choice(all_dtypes)
381 result_tens.setDtype(outputDType)
382
383 # Invalidate Input/Output list for error if checks.
384 input_list = [a.name, b.name]
385 output_list = [result_tens.name]
386 pCount, cCount = op["operands"]
387 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000388 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
389 self, error_name, input_list, output_list
390 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100391
Les Bell729b0352021-11-24 10:28:21 +0000392 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100393 self.ser,
394 validator_fcns,
395 error_name,
396 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000397 input1=a,
398 input2=b,
399 input_dtype=a.dtype,
400 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000401 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100402 input_list=input_list,
403 output_list=output_list,
404 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000405 ):
406 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700407
Kevin Chengaee1fac2020-11-11 13:54:06 -0800408 attr = ts.TosaSerializerAttribute()
409 attr.MulAttribute(shift)
410
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000411 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700412 return result_tens
413
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100414 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
415 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700416
Kevin Chengfe392ce2021-10-18 21:51:55 +0000417 attr = ts.TosaSerializerAttribute()
418 attr.TableAttribute(table)
419
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100420 # Invalidate Input/Output list for error if checks.
421 input_list = [a.name]
422 output_list = [result_tens.name]
423 pCount, cCount = op["operands"]
424 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000425 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
426 self, error_name, input_list, output_list
427 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100428
Les Bell729b0352021-11-24 10:28:21 +0000429 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100430 self.ser,
431 validator_fcns,
432 error_name,
433 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000434 input_shape=a.shape,
435 input_dtype=a.dtype,
436 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000437 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100438 input_list=input_list,
439 output_list=output_list,
440 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000441 ):
442 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100443
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000444 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700445
446 return result_tens
447
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100448 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
449 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
450
451 # Invalidate Input/Output list for error if checks.
452 input_list = [cond.name, a.name, b.name]
453 output_list = [result_tens.name]
454 pCount, cCount = op["operands"]
455 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000456 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
457 self, error_name, input_list, output_list
458 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100459
Les Bell729b0352021-11-24 10:28:21 +0000460 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100461 self.ser,
462 validator_fcns,
463 error_name,
464 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000465 input1=cond,
466 input2=a,
467 input3=b,
468 input_shape=a.shape,
469 input_dtype=a.dtype,
470 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000471 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100472 input_list=input_list,
473 output_list=output_list,
474 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000475 ):
476 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100477
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000478 self.ser.addOperator(
479 op["op"],
480 input_list,
481 output_list,
482 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700483 return result_tens
484
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100485 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000486 result_tens = OutputShaper.binaryComparisonOp(
487 self.ser, self.rng, a, b, error_name
488 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100489
490 # Invalidate Input/Output list for error if checks.
491 input_list = [a.name, b.name]
492 output_list = [result_tens.name]
493 pCount, cCount = op["operands"]
494 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000495 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
496 self, error_name, input_list, output_list
497 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100498
Les Bell729b0352021-11-24 10:28:21 +0000499 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100500 self.ser,
501 validator_fcns,
502 error_name,
503 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000504 input1=a,
505 input2=b,
506 input_shape=a.shape,
507 input_dtype=a.dtype,
508 output_shape=result_tens.shape,
509 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000510 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100511 input_list=input_list,
512 output_list=output_list,
513 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000514 ):
515 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100516
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000517 self.ser.addOperator(
518 op["op"],
519 input_list,
520 output_list,
521 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700522 return result_tens
523
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100524 def build_argmax(self, op, a, axis, validator_fcns, error_name):
525 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
526
527 # Invalidate Input/Output list for error if checks.
528 input_list = [a.name]
529 output_list = [result_tens.name]
530 pCount, cCount = op["operands"]
531 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000532 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
533 self, error_name, input_list, output_list
534 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100535
Les Bell729b0352021-11-24 10:28:21 +0000536 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100537 self.ser,
538 validator_fcns,
539 error_name,
540 op=op,
541 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000542 input_shape=a.shape,
543 input_dtype=a.dtype,
544 output_shape=result_tens.shape,
545 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000546 result_tensors=[result_tens],
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100547 input_list=input_list,
548 output_list=output_list,
549 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000550 ):
551 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700552
553 attr = ts.TosaSerializerAttribute()
554 attr.AxisAttribute(axis)
555
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000556 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700557 return result_tens
558
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000559 def build_pool2d(
560 self,
561 op,
562 input,
James Ward8b390432022-08-12 20:48:56 +0100563 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000564 stride,
565 pad,
566 kernel,
567 validator_fcns=None,
568 error_name=None,
569 qinfo=None,
570 ):
571 result_tens = OutputShaper.pool2dOp(
572 self.ser, self.rng, input, kernel, stride, pad, error_name
573 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100574
575 # Ensure new output type has correct qinfo
576 if error_name == ErrorIf.WrongInputType:
577 if input.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000578 qinfo = [
579 TosaQuantGen.getZeroPoint(self, input.dtype),
580 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
581 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100582
583 # Invalidate Input/Output list for error if checks.
584 input_list = [input.name]
585 output_list = [result_tens.name]
586 pCount, cCount = op["operands"]
587 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000588 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
589 self, error_name, input_list, output_list
590 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100591
Les Bell729b0352021-11-24 10:28:21 +0000592 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100593 self.ser,
594 validator_fcns,
595 error_name,
596 op=op,
597 input_shape=input.shape,
598 input_dtype=input.dtype,
599 output_shape=result_tens.shape,
600 output_dtype=result_tens.dtype,
601 kernel=kernel,
602 stride=stride,
603 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000604 qinfo=qinfo,
Luke Hutton261b7b62023-01-10 14:50:31 +0000605 result_tensors=[result_tens],
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100606 input_list=input_list,
607 output_list=output_list,
608 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000609 ):
610 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700611
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000612 if qinfo is None:
613 qinfo = [0, 0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700614
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000615 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100616 attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000617
618 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700619 return result_tens
620
James Ward8b390432022-08-12 20:48:56 +0100621 def build_maxpool2d(
622 self,
623 op,
624 input,
625 stride,
626 pad,
627 kernel,
628 validator_fcns=None,
629 error_name=None,
630 qinfo=None,
631 ):
632 # Same as build_pool2d but manually sets accum_dtype value
633 # (maxpool has no accum_dtype)
634 return self.build_pool2d(
635 op,
636 input,
637 DType.UNKNOWN,
638 stride,
639 pad,
640 kernel,
641 validator_fcns,
642 error_name,
643 qinfo,
644 )
645
    def build_conv2d(
        self,
        op,
        ifm,
        filter,
        bias,
        accum_dtype,
        strides,
        padding,
        dilations,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Add a 2D convolution of ``ifm`` with ``filter`` and ``bias``.

        ``padding`` must have 4 entries. ``accum_dtype`` is only passed to
        OutputShaper.conv2dOp; padding, strides, dilations and the two zero
        points from ``qinfo`` are stored in the ConvAttribute.

        Returns the result tensor, or None when the ERROR_IF validation rejects
        the generated configuration.
        """
        assert len(padding) == 4
        result_tens = OutputShaper.conv2dOp(
            self.ser,
            self.rng,
            ifm,
            filter,
            accum_dtype,
            strides,
            padding,
            dilations,
            error_name,
        )

        # Recompute the zero points to match a deliberately wrong input type
        if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
            DType.INT8,
            DType.UINT8,
        ):
            qinfo = [
                TosaQuantGen.getZeroPoint(self, ifm.dtype),
                TosaQuantGen.getZeroPoint(self, result_tens.dtype),
            ]

        # Invalidate Input/Output list for error_if checks.
        input_list = [ifm.name, filter.name, bias.name]
        output_list = [result_tens.name]
        num_operands = sum(op["operands"])
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        # NOTE(review): unlike most other builders, result_tensors is not
        # passed to the validators here -- confirm no conv2d validator needs it.
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=ifm.dtype,
            weight_dtype=filter.dtype,
            output_dtype=result_tens.dtype,
            qinfo=qinfo,
            input_list=input_list,
            num_operands=num_operands,
            output_list=output_list,
            pad=padding,
            stride=strides,
            dilation=dilations,
            input_shape=ifm.shape,
            weight_shape=filter.shape,
            output_shape=result_tens.shape,
        ):
            return None

        # NOTE(review): qinfo is indexed unconditionally here, so a None qinfo
        # (the parameter default) would raise TypeError; build_pool2d defaults
        # it to [0, 0] instead -- presumably callers always supply it.
        attr = ts.TosaSerializerAttribute()
        attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

        self.ser.addOperator(op["op"], input_list, output_list, attr)
        return result_tens
717
def build_conv3d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 3D convolution operator.

    Args:
        op: Operator description dict (``op["op"]`` opcode, ``op["operands"]``
            summed into the operand count).
        ifm, filter, bias: Input, weight and bias tensors.
        accum_dtype: Accumulator dtype forwarded to the output shaper.
        strides, padding, dilations: Convolution attributes. ``padding``
            must contain exactly six values.
        validator_fcns, error_name: ERROR_IF validation controls.
        qinfo: Two zero points, used as the last two ConvAttribute arguments.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    assert len(padding) == 6
    out_tens = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # For a WrongInputType test with a non-8-bit input, regenerate the zero
    # points so they agree with the input and output dtypes actually in use.
    input_is_8bit = ifm.dtype in (DType.INT8, DType.UINT8)
    if error_name == ErrorIf.WrongInputType and not input_is_8bit:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tens.dtype),
        ]

    # Invalidate the input/output lists for ERROR_IF checks.
    operand_total = sum(op["operands"])
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tens.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_total,
        output_list=outputs,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tens.shape,
    )
    if not is_valid:
        # The validators rejected this configuration: emit nothing.
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], inputs, outputs, conv_attr)
    return out_tens
789
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D transpose-convolution operator.

    Args:
        op: Operator description dict (``op["op"]`` opcode, ``op["operands"]``
            summed into the operand count).
        ifm, filter, bias: Input, weight and bias tensors.
        accum_dtype: Accumulator dtype forwarded to the output shaper.
        stride: Stride attribute.
        out_pad: Output padding. Must contain exactly four values.
        output_shape: Requested output shape, given both to the output shaper
            and to TransposeConvAttribute.
        validator_fcns, error_name: ERROR_IF validation controls.
        qinfo: Two zero points, used as the last two attribute arguments.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    assert len(out_pad) == 4
    out_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # For a WrongInputType test with a non-8-bit input, regenerate the zero
    # points so they agree with the input and output dtypes actually in use.
    input_is_8bit = ifm.dtype in (DType.INT8, DType.UINT8)
    if error_name == ErrorIf.WrongInputType and not input_is_8bit:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tens.dtype),
        ]

    # Invalidate the input/output lists for ERROR_IF checks.
    operand_total = sum(op["operands"])
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tens.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_total,
        output_list=outputs,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tens.shape,
    )
    if not is_valid:
        # The validators rejected this configuration: emit nothing.
        return None

    tconv_attr = ts.TosaSerializerAttribute()
    tconv_attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1]
    )
    self.ser.addOperator(op["op"], inputs, outputs, tconv_attr)
    return out_tens
852
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D depthwise convolution operator.

    Args:
        op: Operator description dict (``op["op"]`` opcode, ``op["operands"]``
            summed into the operand count).
        ifm, filter, bias: Input, weight and bias tensors.
        accum_dtype: Accumulator dtype forwarded to the output shaper.
        strides, padding, dilations: Convolution attributes.
        validator_fcns, error_name: ERROR_IF validation controls.
        qinfo: Two zero points, used as the last two ConvAttribute arguments.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # For a WrongInputType test with a non-8-bit input, regenerate the zero
    # points so they agree with the input and output dtypes actually in use.
    input_is_8bit = ifm.dtype in (DType.INT8, DType.UINT8)
    if error_name == ErrorIf.WrongInputType and not input_is_8bit:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tens.dtype),
        ]

    # Invalidate the input/output lists for ERROR_IF checks.
    operand_total = sum(op["operands"])
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tens.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_total,
        output_list=outputs,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tens.shape,
    )
    if not is_valid:
        # The validators rejected this configuration: emit nothing.
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], inputs, outputs, conv_attr)
    return out_tens
923
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a fully-connected operator.

    Args:
        op: Operator description dict. ``op["operands"]`` must unpack into
            exactly two counts, which are added to form the operand count.
        ifm, filter, bias: Input, weight and bias tensors.
        accum_dtype: Accumulator dtype (given to the shaper and validators).
        validator_fcns, error_name: ERROR_IF validation controls.
        qinfo: Two zero points passed to FullyConnectedAttribute.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    # Invalidate the input/output lists for ERROR_IF checks.
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=out_tens.shape,
        output_dtype=out_tens.dtype,
        qinfo=qinfo,
        result_tensors=[out_tens],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
        accum_dtype=accum_dtype,
    )
    if not is_valid:
        # The validators rejected this configuration: emit nothing.
        return None

    fc_attr = ts.TosaSerializerAttribute()
    fc_attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], inputs, outputs, fc_attr)
    return out_tens
972
def build_matmul(
    self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a matrix-multiply operator over tensors ``a`` and ``b``.

    Args:
        op: Operator description dict. ``op["operands"]`` must unpack into
            exactly two counts, which are added to form the operand count.
        a, b: First and second operand tensors.
        accum_dtype: Accumulator dtype (given to the shaper and validators).
        validator_fcns, error_name: ERROR_IF validation controls.
        qinfo: Two zero points passed to MatMulAttribute.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate the input/output lists for ERROR_IF checks.
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=out_tens.shape,
        output_dtype=out_tens.dtype,
        qinfo=qinfo,
        result_tensors=[out_tens],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
        accum_dtype=accum_dtype,
    )
    if not is_valid:
        # The validators rejected this configuration: emit nothing.
        return None

    matmul_attr = ts.TosaSerializerAttribute()
    matmul_attr.MatMulAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], inputs, outputs, matmul_attr)
    return out_tens
1014
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize a reduction operator applied to ``a`` along ``axis``.

    Args:
        op: Operator description dict. ``op["operands"]`` must unpack into
            exactly two counts, which are added to form the operand count.
        a: Input tensor.
        axis: Axis to reduce; also stored in the AxisAttribute.
        validator_fcns: ERROR_IF validator functions. Defaults to None, like
            every other ``build_*`` method (it used to be a required
            positional parameter, which was inconsistent with its siblings).
        error_name: ERROR_IF case being generated, if any.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1049
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a clamp operator with randomly chosen bounds.

    Two random values of ``a.dtype`` are drawn. Normally the smaller one
    becomes the minimum. For the MaxSmallerMin error case the values are
    redrawn until they differ, and then the roles are swapped.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def _draw_pair():
        # Two draws, in the same order as the original list literal.
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    bounds = _draw_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # The bounds must differ so that max < min actually triggers the error.
        while bounds[0] == bounds[1]:
            bounds = _draw_pair()
        max_val, min_val = min(bounds), max(bounds)
    else:
        max_val, min_val = max(bounds), min(bounds)

    # Invalidate the input/output lists for ERROR_IF checks.
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=out_tens.shape,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not is_valid:
        # The validators rejected this configuration: emit nothing.
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    if a.dtype not in (DType.BF16, DType.FP16, DType.FP32):
        # Integer bounds go in the first pair; float slots are zeroed.
        clamp_attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        # Float bounds go in the second pair; integer slots are zeroed.
        clamp_attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], inputs, outputs, clamp_attr)
    return out_tens
1105
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a leaky-ReLU operator whose attribute is a random FP32 value.

    ``error_name`` is only forwarded to the output shaper. ``validator_fcns``
    is accepted for signature parity but is not used; no ERROR_IF
    validation happens here.
    """
    out_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    leaky_attr = ts.TosaSerializerAttribute()
    leaky_attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32))

    self.ser.addOperator(op["op"], [a.name], [out_tens.name], leaky_attr)
    return out_tens
1114
# TODO: needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a PReLU operator with a single input and no attribute.

    ``validator_fcns`` is unused and ``error_name`` only reaches the output
    shaper. Only one input tensor is wired up (see the TODO above).
    """
    out_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tens.name])
    return out_tens
1121
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a sigmoid operator (one input, one output, no attribute).

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate the input/output lists for ERROR_IF checks.
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tens.shape,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not is_valid:
        # The validators rejected this configuration: emit nothing.
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tens
1152
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a tanh operator (one input, one output, no attribute).

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate the input/output lists for ERROR_IF checks.
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tens.shape,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not is_valid:
        # The validators rejected this configuration: emit nothing.
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tens
1183
def build_erf(self, op, a, validator_fcns=None, error_name=None):
    """Serialize an erf operator (one input, one output, no attribute).

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate the input/output lists for ERROR_IF checks.
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tens.shape,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not is_valid:
        # The validators rejected this configuration: emit nothing.
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tens
1214
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Serialize a concat operator over a variable number of input tensors.

    Args:
        op: Operator description dict. ``op["operands"]`` must unpack into
            exactly two counts, which are added to form the operand count.
        *a: The input tensors followed by the concat axis as the LAST
            positional argument. The axis has to travel with the variadic
            tensor list, so it rides in the same tuple.
        validator_fcns, error_name: ERROR_IF validation controls.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    if error_name != ErrorIf.WrongInputType:
        # Use isinstance rather than `type(...) == int` so NumPy integer
        # axes are accepted too. bool is excluded explicitly because it
        # subclasses int.
        axis_arg = a[-1]
        assert isinstance(axis_arg, (int, np.integer)) and not isinstance(
            axis_arg, bool
        ), f"concat axis must be an integer, got {type(axis_arg).__name__}"

    # To store variable length list of input tensors we need to store axis along with it
    axis = a[-1]
    a = a[:-1]

    result_tens = OutputShaper.concatOp(
        self.ser, self.rng, axis, *a, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in a]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a[0].shape,
        output_shape=result_tens.shape,
        input_dtype=a[0].dtype,
        output_dtype=result_tens.dtype,
        inputs=a,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001263
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a pad operator.

    Args:
        op: Operator description dict. ``op["operands"]`` must unpack into
            exactly two counts, which are added to form the operand count.
        a: Input tensor.
        padding: Padding amounts. ``padding.flatten()`` is what goes into
            the attribute (looks like a NumPy array; confirm with caller).
        pad_const_int, pad_const_float: Pad constants stored in the attribute.
        validator_fcns, error_name: ERROR_IF validation controls.
        qinfo: Passed to the validators only.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    # The attribute is built against self.ser.builder before validation runs,
    # matching the original statement order.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(
        self.ser.builder, padding.flatten(), pad_const_int, pad_const_float
    )

    # Invalidate the input/output lists for ERROR_IF checks.
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tens.shape,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out_tens],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
        input1=a,
    )
    if not is_valid:
        # The validators rejected this configuration: emit nothing.
        return None

    self.ser.addOperator(op["op"], inputs, outputs, pad_attr)
    return out_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001312
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Serialize a reshape of ``a`` to ``newShape``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    # Invalidate the input/output lists for ERROR_IF checks.
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tens.shape,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not is_valid:
        # The validators rejected this configuration: emit nothing.
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)
    self.ser.addOperator(op["op"], inputs, outputs, reshape_attr)
    return out_tens
1348
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize a reverse of ``a`` along ``axis``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate the input/output lists for ERROR_IF checks.
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tens.shape,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not is_valid:
        # The validators rejected this configuration: emit nothing.
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], inputs, outputs, axis_attr)
    return out_tens
1383
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Serialize a transpose of ``a`` using the permutation ``perms``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    # Invalidate the input/output lists for ERROR_IF checks.
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tens.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
        input1=a,
    )
    if not is_valid:
        # The validators rejected this configuration: emit nothing.
        return None

    self.ser.addOperator(op["op"], inputs, outputs, perm_attr)
    return out_tens
1419
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Build a SLICE operator taking a `size` window of `a` beginning at `start`.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration (in which case no operator is added).
    """
    out_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, a, start, size, error_name
    )

    n_placeholders, n_consts = op["operands"]

    # ERROR_IF tests may invalidate the input/output name lists.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensors=[out_tensor],
        input_list=ins,
        output_list=outs,
        num_operands=n_placeholders + n_consts,
        input1=a,
    )
    if not checks_ok:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], ins, outs, slice_attr)
    return out_tensor
1458
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Build a TILE operator that repeats `a` according to `multiples`.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration (in which case no operator is added).
    """
    out_tensor = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    n_placeholders, n_consts = op["operands"]

    # ERROR_IF tests may invalidate the input/output name lists.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=ins,
        output_list=outs,
        num_operands=n_placeholders + n_consts,
        input1=a,
    )
    if not checks_ok:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)
    self.ser.addOperator(op["op"], ins, outs, tile_attr)
    return out_tensor
1493
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Build a GATHER operator reading from `values` through random indices.

    A constant INT32 index tensor of shape [values.shape[0], W] is created.
    W is drawn with self.randInt over self.args.tensor_shape_range, and every
    index is drawn from [0, values.shape[1]) so it never exceeds the K
    dimension of `values`.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration.
    """
    shape_range = self.args.tensor_shape_range
    k_dim = values.shape[1]  # K
    width = self.randInt(shape_range[0], shape_range[1])  # W
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values.shape[0], width])
    )  # (N, W)
    index_tens = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.gatherOp(
        self.ser, self.rng, values, index_tens, error_name
    )

    n_placeholders, n_consts = op["operands"]

    # ERROR_IF tests may invalidate the input/output name lists.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, index_tens.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=ins,
        output_list=outs,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], ins, outs)

    return out_tensor
1540
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Build a SCATTER operator writing `input` rows into `values_in`.

    A constant INT32 index tensor of shape [values_in.shape[0], input.shape[1]]
    is created, with every index drawn from [0, values_in.shape[1]).

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration.
    """
    k_dim = values_in.shape[1]  # K
    width = input.shape[1]  # W
    # NOTE(review): indices are drawn with replacement, so the same output
    # index may appear more than once within a batch; confirm this is
    # acceptable for SCATTER under the TOSA spec.
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values_in.shape[0], width])
    )  # (N, W)
    index_tens = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, index_tens, input, error_name
    )

    n_placeholders, n_consts = op["operands"]

    # ERROR_IF tests may invalidate the input/output name lists.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, index_tens.name, input.name],
        [out_tensor.name],
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=ins,
        output_list=outs,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], ins, outs)

    return out_tensor
1585
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE operator on `input`.

    `mode`, `scale`, `offset` and `border` are forwarded both to the output
    shaper and to the serialized ResizeAttribute. The ERROR_IF validators
    are given the `input_dtype`/`output_dtype` arguments rather than the
    tensors' own dtypes.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration.
    """
    out_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    n_placeholders, n_consts = op["operands"]

    # ERROR_IF tests may invalidate the input/output name lists.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=out_tensor.shape,
        offset=offset,
        border=border,
        input_list=ins,
        output_list=outs,
        result_tensors=[out_tensor],
        num_operands=n_placeholders + n_consts,
    )
    if not checks_ok:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], ins, outs, resize_attr)
    return out_tensor
1647
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator passing `val` and `val2` through unchanged.

    Two output tensors are created, but only the first is returned to the
    caller. No ERROR_IF validation is performed for this operator.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # `op` is the operator description dict, as in every other build_*
    # method; the serializer needs the Op value stored under its "op" key.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1655
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register the already-built constant `val` as an output and return it.

    No operator is added; `op`, `validator_fcns` and `error_name` are unused.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001659
1660 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Build a CAST operator converting `val` to `out_dtype`.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration (in which case no operator is added).
    """
    out_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    n_placeholders, n_consts = op["operands"]

    # ERROR_IF tests may invalidate the input/output name lists.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=out_tensor.shape,
        input_dtype=val.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=ins,
        output_list=outs,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], ins, outs)
    return out_tensor
1693
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator converting `val` to `out_dtype`.

    Random input/output zero points are chosen to suit the data types, or
    deliberately non-zero values are chosen for the zero-point ERROR_IF
    tests. A random scale (one per channel of the last dimension when
    `per_channel`, otherwise a single one) is converted into multiplier and
    shift pairs by TosaQuantGen.computeMultiplierAndShift.

    For scale32 tests with no error_name, the placeholder input data file is
    clamped (and rewritten if it changed) so that every value minus input_zp
    lies within [-(1 << (shift - 1)), (1 << (shift - 1)) - 1].

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # One scale per channel (last dimension) or a single shared scale
    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a *(2^output_width)/(2^input_width))
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Do the bound arithmetic in int64: shift_arr is int32 and scale32
        # shifts can approach 62 bits, which overflows int32 when NumPy keeps
        # the scalar's dtype (NEP 50 promotion in NumPy >= 2).
        shift_m1 = np.int64(shift_arr[i]) - np.int64(1)
        min_shift_value_arr[i] = np.int64(-1) << shift_m1
        max_shift_value_arr[i] = (np.int64(1) << shift_m1) - np.int64(1)

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        data_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(data_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(data_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1848
def _get_condition_tensor(self, op, cond, error_name):
    """Create and return the constant condition tensor for a COND_IF test.

    Normally this is a rank-0 BOOL constant holding `cond`. The
    CondIfCondNotMatchingBool ERROR_IF test gets a non-BOOL dtype, and the
    CondIfCondShapeNotSizeOne test gets a shape of [2] or [1, 2].
    """
    cond_type = (
        get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    else:
        cond_shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
1865
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks each output an INT32 constant.

    Of the supplied tensors only the shape of `then_tens` is used, and
    `else_tens` is ignored. For the output-list mismatch ERROR_IF tests the
    affected block instead outputs a constant with a perturbed shape.

    Returns the COND_IF result tensor, or None when the ERROR_IF validators
    reject the built graph.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    # Create an incorrect output shape for error_if tests
    if error_name in (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ):
        bad_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(bad_shape):
            if dim > 3:
                bad_shape[idx] = dim + self.rng.choice([-3, -2, 2, 3])
            else:
                bad_shape[idx] = dim + self.rng.choice([1, 2, 4])
        bad_arr = np.int32(self.rng.integers(0, 256, size=bad_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    cond_attr = ts.TosaSerializerAttribute()
    cond_attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], cond_attr)

    # Each block holds one constant node; the mismatch test swaps in the
    # incorrectly-shaped constant for its block.
    for block_name, block_arr, mismatch_error in (
        (then_block, then_arr, ErrorIf.CondIfOutputListThenGraphMismatch),
        (else_block, else_arr, ErrorIf.CondIfOutputListElseGraphMismatch),
    ):
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            const_tens = self.ser.addConst(bad_shape, DType.INT32, bad_arr)
        else:
            const_tens = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(const_tens)

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not checks_ok:
        return None

    return result_tens
1934
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks apply a binary op to `a` and `b`.

    FP32/BF16/FP16/INT32 inputs use ADD (THEN) and SUB (ELSE). INT8/INT16
    inputs use LOGICAL_RIGHT_SHIFT (THEN) and LOGICAL_LEFT_SHIFT (ELSE). For
    the input/output-list mismatch ERROR_IF tests, the affected block is
    given an input or output tensor with a perturbed shape.

    Returns the COND_IF result tensor, or None when the ERROR_IF validators
    reject the built graph.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Use a distinct loop variable: rebinding `op` here would make the
    # ERROR_IF validation below receive the ELSE block's Op value instead
    # of this operator's description dict.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2017
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP test that counts a scalar loop counter down to zero.

    The outer graph feeds three tensors into the WHILE_LOOP operator:
    ``iter`` (a rank-0 INT32 placeholder holding ``iter_val``), ``a`` and
    ``acc`` (a placeholder of ``a``'s shape/dtype initialised to zeros).
    The COND block computes ``iter > 0``; the BODY block computes
    ``acc + a`` and ``iter - 1`` and passes ``a`` through unchanged.

    For the ERROR_IF cases handled here (InputListOutputListMismatch,
    InputListCondGraphMismatch, InputListBodyGraphInputMismatch,
    InputListBodyGraphOutputMismatch, CondGraphOutputNotMatchingBool,
    CondGraphOutputShapeNotSizeOne) the relevant tensor shape or dtype is
    deliberately perturbed so that the generated graph is invalid.

    Args:
        op: Operator entry from TOSA_OP_LIST; ``op["op"]`` is the opcode.
        a: Input tensor added to the accumulator on each iteration.
        iter_val: Initial value of the loop counter.
        validator_fcns: ERROR_IF validator functions, or None.
        error_name: ErrorIf name for negative tests, or None.

    Returns:
        ``acc_out`` (the accumulator output of the outer graph), or None if
        the ERROR_IF validators reject the generated test.
    """
    # NOTE: `iter` shadows the Python builtin inside this function.
    iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor
    # acc = self.ser.addOutput(a.shape, a.dtype)
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        # Deep-copy so the real `acc` keeps its shape; offset every dimension
        # by a random non-zero amount so the output list no longer matches.
        # NOTE(review): negative offsets can make a small dimension <= 0.
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        # Mis-shaped copies of iter/acc used by the COND/BODY blocks below.
        incorrect_iter = deepcopy(iter)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
        # iter is rank 0, so the loop above changes nothing; add a dimension
        # instead to make its shape differ.
        if len(incorrect_iter.shape) == 0:
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

    # COND block (inputs: iter, a, acc; output: cond_tens = iter > 0)
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    # The condition output is normally a rank-0 BOOL; the two ERROR_IF cases
    # below break the dtype or the shape respectively.
    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])

    # BODY block (inputs: iter, a, acc; outputs: iter - 1, a, acc + a)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2138
def build_fft2d(
    self, op, val1, val2, inverse, validator_fcns=None, error_name=None
):
    """Build an FFT2D test from two input tensors.

    The result tensors come from ``OutputShaper.fft2dOp``. The input and
    output name lists are then passed through
    ``TosaErrorIfArgGen.eiInvalidateInputOutputList``, presumably so that
    list-related ERROR_IF cases can corrupt them. After that the ERROR_IF
    validators run.

    Args:
        op: Operator entry from TOSA_OP_LIST.
        val1: First input tensor.
        val2: Second input tensor.
        inverse: Value passed to ``FFTAttribute``.
        validator_fcns: ERROR_IF validator functions, or None.
        error_name: ErrorIf name for negative tests, or None.

    Returns:
        The list of result tensors, or None if validation rejects the test.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]
    result_shapes = [tensor.shape for tensor in results]
    result_dtypes = [tensor.dtype for tensor in results]

    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [val1.name, val2.name],
        [tensor.name for tensor in results],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse)
    self.ser.addOperator(op["op"], in_names, out_names, fft_attr)
    return results
2180
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
    """Build an RFFT2D test from a single input tensor.

    The result tensors come from ``OutputShaper.rfft2dOp``. The name lists
    are then passed through
    ``TosaErrorIfArgGen.eiInvalidateInputOutputList`` for ERROR_IF cases
    before the validators run. No operator attribute is attached.

    Args:
        op: Operator entry from TOSA_OP_LIST.
        val: Input tensor.
        validator_fcns: ERROR_IF validator functions, or None.
        error_name: ErrorIf name for negative tests, or None.

    Returns:
        The list of result tensors, or None if validation rejects the test.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholder_count, const_count = op["operands"]
    result_shapes = [tensor.shape for tensor in results]
    result_dtypes = [tensor.dtype for tensor in results]

    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [tensor.name for tensor in results]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return results
2214
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out which shapes, ranks and dtypes to generate tests for.

    Args:
        op: Operator entry from TOSA_OP_LIST (uses ``op["rank"]`` and
            ``op["types"]``).
        shapeFilter: Requested shapes; a falsy value means "no restriction".
        rankFilter: Requested ranks, or None.
        dtypeFilter: Requested dtypes, or None. A list-valued op dtype
            matches when its first element is requested.
        testType: "positive" or "negative".
        validator: ERROR_IF validator. For negative tests, its
            ``param_reqs`` override the rank/dtype/shape filters.

    Returns:
        A dict with keys "shapeFilter", "rankFilter" and "dtypeFilter".
        Returns None for a negative test without a validator, or for any
        other testType.
    """
    # Default testing ranks 1-4 inclusive keep test counts reasonably small.
    small_ranks = range(1, 5)
    shapes = shapeFilter if shapeFilter else [None]

    lo, hi = op["rank"]
    if rankFilter is not None:
        # Only keep requested ranks that the operator allows.
        ranks = [rank for rank in rankFilter if lo <= rank <= hi]
    elif shapes[0] is None:
        # No ranks or shapes requested: use whichever is smaller, the
        # operator's range or the small default range.
        op_ranks = range(lo, hi + 1)
        ranks = op_ranks if len(op_ranks) <= len(small_ranks) else small_ranks
    else:
        ranks = range(lo, hi + 1)

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapes,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }
    if testType != "negative" or validator is None:
        return None

    # The validator states which parameters the error case requires.
    reqs = validator(check=False, op=op)["param_reqs"]
    req_rank = reqs["rank"]
    req_dtype = reqs["dtype"]
    req_shape = reqs["shape"]
    return {
        # Otherwise trim to two shapes to keep the number of tests small.
        "shapeFilter": req_shape if req_shape is not None else shapes[:2],
        "rankFilter": req_rank if req_rank is not None else ranks,
        "dtypeFilter": req_dtype if req_dtype is not None else dtypes,
    }
2295
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests to create for one operator.

    Args:
        opName: Name of the operator entry in TOSA_OP_LIST.
        shapeFilter: Optional list of shapes to restrict generation to.
            ``None`` (the default) means no shape restriction; it is treated
            exactly like the previous ``[None]`` default.
        rankFilter: Optional list of ranks to restrict generation to.
        dtypeFilter: Optional list of dtypes to restrict generation to.
        testType: "positive" for normal tests, "negative" for ERROR_IF tests.

    Returns:
        A list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples, one per test.

    Raises:
        Exception: If ``opName`` is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Avoid a shared mutable default argument; [None] means "no shape filter".
    if shapeFilter is None:
        shapeFilter = [None]

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    # Only the tensor-shape and argument generators are needed here.
    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        # Single pass without a validator.
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, []) tuples: the string
                    # representation and the build function argument list.
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    # create_filter_lists returned a dict, so testType is
                    # either "positive" or "negative" here.
                    if testType == "positive":
                        prefix = opName
                    else:
                        prefix = "{}_ERRORIF_{}".format(opName, error_name)
                    baseStr = "{}_{}_{}".format(prefix, shapeStr, typeStr)

                    for argStr, args in argList:
                        if argStr:
                            testStr = "{}_{}".format(baseStr, argStr)
                        else:
                            testStr = baseStr
                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a
        # ERROR_IF statement. Every validator is evaluated for every test.
        invalid_test_validators = op["invalid_test_validators"]
        clean_testList = []
        for test in testList:
            remove_test = False
            for validator_fcn in invalid_test_validators:
                if validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                ):
                    remove_test = True
            if not remove_test:
                clean_testList.append(test)
        testList = clean_testList

    return testList
2402
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build a single test for ``opName`` and serialize it if it is valid.

    Args:
        opName: Name of the operator entry in TOSA_OP_LIST.
        testStr: Test name, passed to ``createSerializer``.
        dtype_or_dtypeList: Either a single dtype, replicated for every
            operand, or an explicit per-operand list of dtypes.
        error_name: ErrorIf name for negative tests, or None.
        shapeList: One shape per operand.
        testArgs: Extra positional arguments for the op's build function.

    Raises:
        Exception: If ``opName`` is not in TOSA_OP_LIST.
        TypeError: Re-raised from the build function after printing its
            inputs.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _tgen_fcn, tvgen_fcn, _agen_fcn = op["build_fcn"]
    validators = op.get("error_if_validators")

    placeholder_count, const_count = op["operands"]
    operand_count = placeholder_count + const_count
    is_concat = op["op"] == Op.CONCAT

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        # CONCAT is not sized by op["operands"] here; it uses one dtype per
        # entry in shapeList.
        repeat = len(shapeList) if is_concat else operand_count
        dtypeList = [dtype_or_dtypeList] * repeat

    if not is_concat:
        assert (
            len(shapeList) == operand_count
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), operand_count
        )
        assert (
            len(dtypeList) == operand_count
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), operand_count
        )

    qgen = op.get("qgen")
    if qgen is None:
        qinfo = None
    else:
        qinfo = qgen(self, op, dtype_or_dtypeList, error_name)

    # Build the random tensor operands and the test
    tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

    try:
        positional = [self, op, *tens, *testArgs]
        keywords = {}
        if validators is None:
            # Without validators, qinfo (if any) is passed positionally.
            if qinfo is not None:
                positional.append(qinfo)
        else:
            keywords["validator_fcns"] = validators
            keywords["error_name"] = error_name
            if qinfo is not None:
                keywords["qinfo"] = qinfo
        resultName = build_fcn(*positional, **keywords)
    except TypeError as e:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise e

    if not resultName:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
    else:
        # The test is valid, serialize it
        self.serialize("test")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002496
def createDynamicOpLists(self):
    """Expand the templated convolution entries of TOSA_OP_LIST.

    For every kernel size, a shallow copy of the matching ``*_TEMPLATE``
    entry is added under a concrete name such as ``conv2d_3x3``. The copy's
    ``filter`` is set to the kernel and ``template`` is set to False. Every
    entry still flagged as a template is then deleted. If
    ``conv2d_TEMPLATE`` is already gone, the method does nothing.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    # Dynamically create op lists for convolutions with a list of kernel sizes
    if self.args.level8k:
        big_kernel = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_kernel], [big_kernel, 2]]
        kernels_3d = [[1, big_kernel, 1], [2, 2, big_kernel]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def instantiate(prefix, kernel):
        # Copy the template, then override only "filter" and "template".
        entry = op_list[prefix + "_TEMPLATE"].copy()
        entry["filter"] = kernel
        entry["template"] = False
        dims = "x".join(str(dim) for dim in kernel)
        op_list["{}_{}".format(prefix, dims)] = entry

    for kernel in kernels_2d:
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            instantiate(prefix, kernel)
    for kernel in kernels_3d:
        instantiate("conv3d", kernel)

    # Collect first, then delete: keys must not be removed while iterating.
    stale = [name for name, entry in op_list.items() if entry.get("template")]
    for name in stale:
        del op_list[name]
2552
def initOpListDefaults(self):
    """Lint every TOSA_OP_LIST entry and fill in optional defaults.

    Each entry must have:
    - a 2-item "operands" value;
    - a 4-item "build_fcn" value;
    - a "types" key;
    - an "op" key.

    A missing "rank" is set to ``DEFAULT_RANK_RANGE``.

    Raises:
        Exception: Naming the first offending op and the missing field.
    """
    for name, entry in self.TOSA_OP_LIST.items():
        # Required fields (unpacking also checks the tuple lengths)
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(name)
            )

        try:
            _build, _tgen, _tvgen, _argsgen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    name
                )
            )

        if "types" not in entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(name)
            )

        if "op" not in entry:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(name)
            )

        # Put in default rank range, if missing
        if "rank" not in entry:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002594
# Tensor operator list (TOSA_OP_LIST); each entry is a dict with:
# 'op': the Op enum value of the operator under test (e.g. Op.ARGMAX)
# 'operands': tuple of (placeholder, const) operands
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#   if not specified, initOpListDefaults sets it to DEFAULT_RANK_RANGE
#   i.e. (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': 4-tuple of (build_operator(), TensorGen function,
#   TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested
# Optional keys seen in entries: 'qgen' (quantization info generator),
#   'error_if_validators' (ERROR_IF validators for negative tests),
#   'invalid_test_validators' (predicates that drop positive tests),
#   'template' / 'filter' (expanded by createDynamicOpLists)
# Floating-point types
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
# Integer and floating-point types
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
# Floating-point, integer (excluding INT4) and BOOL types
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
# FP32 and INT16 only
TYPE_FI16 = [DType.FP32, DType.INT16]

# INT8/INT16 plus the floating-point types (no INT32, no INT4)
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
# (dtype filtering in create_filter_lists matches list entries on their
# first element)
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range used when an op entry does not specify 'rank'
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002646
2647 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002648 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002649 "argmax": {
2650 "op": Op.ARGMAX,
2651 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002652 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002653 "build_fcn": (
2654 build_argmax,
2655 TosaTensorGen.tgBasic,
2656 TosaTensorValuesGen.tvgDefault,
2657 TosaArgGen.agAxis,
2658 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002659 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002660 "error_if_validators": (
2661 TosaErrorValidator.evAxisSmallerZero,
2662 TosaErrorValidator.evAxisLargerRank,
2663 TosaErrorValidator.evArgmaxOutputRankMismatch,
2664 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2665 TosaErrorValidator.evWrongRank,
2666 TosaErrorValidator.evWrongInputType,
2667 TosaErrorValidator.evWrongOutputType,
2668 TosaErrorValidator.evWrongInputList,
2669 TosaErrorValidator.evWrongOutputList,
2670 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002671 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002672 "avg_pool2d": {
2673 "op": Op.AVG_POOL2D,
2674 "operands": (1, 0),
2675 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002676 "build_fcn": (
2677 build_pool2d,
2678 TosaTensorGen.tgNHWC,
2679 TosaTensorValuesGen.tvgDefault,
2680 TosaArgGen.agPooling,
2681 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002682 "qgen": TosaQuantGen.qgUnary,
2683 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002684 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002685 "error_if_validators": (
2686 TosaErrorValidator.evKernelSmallerOne,
2687 TosaErrorValidator.evStrideSmallerOne,
2688 TosaErrorValidator.evPadSmallerZero,
2689 TosaErrorValidator.evWrongRank,
2690 TosaErrorValidator.evWrongInputType,
2691 TosaErrorValidator.evWrongOutputType,
2692 TosaErrorValidator.evWrongInputList,
2693 TosaErrorValidator.evWrongOutputList,
2694 TosaErrorValidator.evInputZeroPointNotZero,
2695 TosaErrorValidator.evOutputZeroPointNotZero,
2696 TosaErrorValidator.evPadLargerEqualKernel,
2697 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002698 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002699 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002700 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002701 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002702 "conv2d_TEMPLATE": {
2703 "op": Op.CONV2D,
2704 "operands": (1, 2),
2705 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002706 "build_fcn": (
2707 build_conv2d,
2708 TosaTensorGen.tgConv2D,
2709 TosaTensorValuesGen.tvgDefault,
2710 TosaArgGen.agConv,
2711 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002712 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002713 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002714 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2715 "error_if_validators": (
2716 TosaErrorValidator.evWrongInputType,
2717 TosaErrorValidator.evWrongOutputType,
2718 TosaErrorValidator.evWrongInputList,
2719 TosaErrorValidator.evWrongOutputList,
2720 TosaErrorValidator.evInputZeroPointNotZero,
2721 TosaErrorValidator.evWeightZeroPointNotZero,
2722 TosaErrorValidator.evPadSmallerZero,
2723 TosaErrorValidator.evStrideSmallerOne,
2724 TosaErrorValidator.evDilationSmallerOne,
2725 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002726 TosaErrorValidator.evConvOutputShapeMismatch,
2727 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002728 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002729 "template": True,
2730 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002731 # Templated operator. Filled in by createDynamicOpLists
2732 "conv3d_TEMPLATE": {
2733 "op": Op.CONV3D,
2734 "operands": (1, 2),
2735 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002736 "build_fcn": (
2737 build_conv3d,
2738 TosaTensorGen.tgConv3D,
2739 TosaTensorValuesGen.tvgDefault,
2740 TosaArgGen.agConv,
2741 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002742 "qgen": TosaQuantGen.qgConv,
2743 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002744 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2745 "error_if_validators": (
2746 TosaErrorValidator.evWrongInputType,
2747 TosaErrorValidator.evWrongOutputType,
2748 TosaErrorValidator.evWrongInputList,
2749 TosaErrorValidator.evWrongOutputList,
2750 TosaErrorValidator.evInputZeroPointNotZero,
2751 TosaErrorValidator.evWeightZeroPointNotZero,
2752 TosaErrorValidator.evPadSmallerZero,
2753 TosaErrorValidator.evStrideSmallerOne,
2754 TosaErrorValidator.evDilationSmallerOne,
2755 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002756 TosaErrorValidator.evConvOutputShapeMismatch,
2757 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002758 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002759 "template": True,
2760 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002761 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002762 "depthwise_conv2d_TEMPLATE": {
2763 "op": Op.DEPTHWISE_CONV2D,
2764 "operands": (1, 2),
2765 "filter": [1, 1],
2766 "rank": (4, 4),
2767 "build_fcn": (
2768 build_depthwise_conv2d,
2769 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002770 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002771 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002772 ),
2773 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002774 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002775 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2776 "error_if_validators": (
2777 TosaErrorValidator.evWrongInputType,
2778 TosaErrorValidator.evWrongOutputType,
2779 TosaErrorValidator.evWrongInputList,
2780 TosaErrorValidator.evWrongOutputList,
2781 TosaErrorValidator.evInputZeroPointNotZero,
2782 TosaErrorValidator.evWeightZeroPointNotZero,
2783 TosaErrorValidator.evPadSmallerZero,
2784 TosaErrorValidator.evStrideSmallerOne,
2785 TosaErrorValidator.evDilationSmallerOne,
2786 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002787 TosaErrorValidator.evConvOutputShapeMismatch,
2788 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002789 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002790 "template": True,
2791 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002792 "fully_connected": {
2793 "op": Op.FULLY_CONNECTED,
2794 "operands": (1, 2),
2795 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002796 "build_fcn": (
2797 build_fully_connected,
2798 TosaTensorGen.tgFullyConnected,
2799 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002800 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002801 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002802 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002803 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002804 "error_if_validators": (
2805 TosaErrorValidator.evInputZeroPointNotZero,
2806 TosaErrorValidator.evWeightZeroPointNotZero,
2807 TosaErrorValidator.evWrongRank,
2808 TosaErrorValidator.evWrongInputType,
2809 TosaErrorValidator.evWrongOutputType,
2810 TosaErrorValidator.evWrongInputList,
2811 TosaErrorValidator.evWrongOutputList,
2812 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002813 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002814 "matmul": {
2815 "op": Op.MATMUL,
2816 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002817 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002818 "build_fcn": (
2819 build_matmul,
2820 TosaTensorGen.tgMatmul,
2821 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002822 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002823 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002824 "qgen": TosaQuantGen.qgMatmul,
2825 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002826 "error_if_validators": (
2827 TosaErrorValidator.evInputZeroPointNotZero,
2828 TosaErrorValidator.evWrongRank,
2829 TosaErrorValidator.evWrongInputType,
2830 TosaErrorValidator.evWrongOutputType,
2831 TosaErrorValidator.evWrongInputList,
2832 TosaErrorValidator.evWrongOutputList,
2833 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002834 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002835 "max_pool2d": {
2836 "op": Op.MAX_POOL2D,
2837 "operands": (1, 0),
2838 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002839 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01002840 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002841 TosaTensorGen.tgNHWC,
2842 TosaTensorValuesGen.tvgDefault,
2843 TosaArgGen.agPooling,
2844 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002845 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002846 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002847 "error_if_validators": (
2848 TosaErrorValidator.evKernelSmallerOne,
2849 TosaErrorValidator.evStrideSmallerOne,
2850 TosaErrorValidator.evPadSmallerZero,
2851 TosaErrorValidator.evWrongRank,
2852 TosaErrorValidator.evWrongInputType,
2853 TosaErrorValidator.evWrongOutputType,
2854 TosaErrorValidator.evWrongInputList,
2855 TosaErrorValidator.evWrongOutputList,
2856 TosaErrorValidator.evPadLargerEqualKernel,
2857 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002858 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002859 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002860 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002861 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002862 "transpose_conv2d_TEMPLATE": {
2863 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002864 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002865 "rank": (4, 4),
2866 "build_fcn": (
2867 build_transpose_conv2d,
2868 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002869 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002870 TosaArgGen.agTransposeConv2D,
2871 ),
2872 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002873 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002874 "invalid_test_validators": (
2875 TosaInvalidValidator.ivHeightWidthInvalid,
2876 TosaInvalidValidator.ivNonPositiveOutputShape,
2877 ),
2878 "error_if_validators": (
2879 TosaErrorValidator.evWrongInputType,
2880 TosaErrorValidator.evWrongOutputType,
2881 TosaErrorValidator.evWrongInputList,
2882 TosaErrorValidator.evWrongOutputList,
2883 TosaErrorValidator.evInputZeroPointNotZero,
2884 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07002885 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00002886 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002887 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002888 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002889 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002890 "template": True,
2891 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002892 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002893 "clamp": {
2894 "op": Op.CLAMP,
2895 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002896 "build_fcn": (
2897 build_clamp,
2898 TosaTensorGen.tgBasic,
2899 TosaTensorValuesGen.tvgDefault,
2900 None,
2901 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002902 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002903 "error_if_validators": (
2904 TosaErrorValidator.evMaxSmallerMin,
2905 TosaErrorValidator.evWrongInputType,
2906 TosaErrorValidator.evWrongOutputType,
2907 TosaErrorValidator.evWrongInputList,
2908 TosaErrorValidator.evWrongOutputList,
2909 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002910 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002911 "sigmoid": {
2912 "op": Op.SIGMOID,
2913 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002914 "build_fcn": (
2915 build_sigmoid,
2916 TosaTensorGen.tgBasic,
2917 TosaTensorValuesGen.tvgDefault,
2918 None,
2919 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002920 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002921 "error_if_validators": (
2922 TosaErrorValidator.evWrongInputType,
2923 TosaErrorValidator.evWrongOutputType,
2924 TosaErrorValidator.evWrongInputList,
2925 TosaErrorValidator.evWrongOutputList,
2926 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002927 },
2928 "tanh": {
2929 "op": Op.TANH,
2930 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002931 "build_fcn": (
2932 build_tanh,
2933 TosaTensorGen.tgBasic,
2934 TosaTensorValuesGen.tvgDefault,
2935 None,
2936 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002937 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002938 "error_if_validators": (
2939 TosaErrorValidator.evWrongInputType,
2940 TosaErrorValidator.evWrongOutputType,
2941 TosaErrorValidator.evWrongInputList,
2942 TosaErrorValidator.evWrongOutputList,
2943 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002944 },
Won Jeon78155c62023-06-10 00:20:04 +00002945 "erf": {
2946 "op": Op.ERF,
2947 "operands": (1, 0),
2948 "build_fcn": (
2949 build_erf,
2950 TosaTensorGen.tgBasic,
2951 TosaTensorValuesGen.tvgDefault,
2952 None,
2953 ),
2954 "types": TYPE_FP,
2955 "error_if_validators": (
2956 TosaErrorValidator.evWrongInputType,
2957 TosaErrorValidator.evWrongOutputType,
2958 TosaErrorValidator.evWrongInputList,
2959 TosaErrorValidator.evWrongOutputList,
2960 ),
2961 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002962 # Elementwise Binary Operators
2963 "add": {
2964 "op": Op.ADD,
2965 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002966 "build_fcn": (
2967 build_binary_broadcast,
2968 TosaTensorGen.tgBroadcastFuzz,
2969 TosaTensorValuesGen.tvgAddSub,
2970 None,
2971 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002972 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002973 "error_if_validators": (
2974 TosaErrorValidator.evRankMismatch,
2975 TosaErrorValidator.evWrongInputType,
2976 TosaErrorValidator.evWrongOutputType,
2977 TosaErrorValidator.evWrongInputList,
2978 TosaErrorValidator.evWrongOutputList,
2979 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00002980 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002981 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002982 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002983 "arithmetic_right_shift": {
2984 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2985 "operands": (2, 0),
2986 "build_fcn": (
2987 build_arithmetic_right_shift,
2988 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002989 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002990 TosaArgGen.agArithmeticRightShift,
2991 ),
2992 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002993 "error_if_validators": (
2994 TosaErrorValidator.evRankMismatch,
2995 TosaErrorValidator.evWrongInputType,
2996 TosaErrorValidator.evWrongOutputType,
2997 TosaErrorValidator.evWrongInputList,
2998 TosaErrorValidator.evWrongOutputList,
2999 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003000 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003001 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003002 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003003 "bitwise_and": {
3004 "op": Op.BITWISE_AND,
3005 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003006 "build_fcn": (
3007 build_binary_broadcast,
3008 TosaTensorGen.tgBroadcastFuzz,
3009 TosaTensorValuesGen.tvgDefault,
3010 None,
3011 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003012 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003013 "error_if_validators": (
3014 TosaErrorValidator.evRankMismatch,
3015 TosaErrorValidator.evWrongInputType,
3016 TosaErrorValidator.evWrongOutputType,
3017 TosaErrorValidator.evWrongInputList,
3018 TosaErrorValidator.evWrongOutputList,
3019 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003020 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003021 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003022 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003023 "bitwise_or": {
3024 "op": Op.BITWISE_OR,
3025 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003026 "build_fcn": (
3027 build_binary_broadcast,
3028 TosaTensorGen.tgBroadcastFuzz,
3029 TosaTensorValuesGen.tvgDefault,
3030 None,
3031 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003032 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003033 "error_if_validators": (
3034 TosaErrorValidator.evRankMismatch,
3035 TosaErrorValidator.evWrongInputType,
3036 TosaErrorValidator.evWrongOutputType,
3037 TosaErrorValidator.evWrongInputList,
3038 TosaErrorValidator.evWrongOutputList,
3039 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003040 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003041 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003042 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003043 "bitwise_xor": {
3044 "op": Op.BITWISE_XOR,
3045 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003046 "build_fcn": (
3047 build_binary_broadcast,
3048 TosaTensorGen.tgBroadcastFuzz,
3049 TosaTensorValuesGen.tvgDefault,
3050 None,
3051 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003052 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003053 "error_if_validators": (
3054 TosaErrorValidator.evRankMismatch,
3055 TosaErrorValidator.evWrongInputType,
3056 TosaErrorValidator.evWrongOutputType,
3057 TosaErrorValidator.evWrongInputList,
3058 TosaErrorValidator.evWrongOutputList,
3059 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003060 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003061 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003062 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003063 "intdiv": {
3064 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003065 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003066 "build_fcn": (
3067 build_binary_broadcast,
3068 TosaTensorGen.tgBroadcastFuzz,
3069 TosaTensorValuesGen.tvgIntDiv,
3070 None,
3071 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003072 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003073 "error_if_validators": (
3074 TosaErrorValidator.evRankMismatch,
3075 TosaErrorValidator.evWrongInputType,
3076 TosaErrorValidator.evWrongOutputType,
3077 TosaErrorValidator.evWrongInputList,
3078 TosaErrorValidator.evWrongOutputList,
3079 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003080 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003081 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003082 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003083 "logical_and": {
3084 "op": Op.LOGICAL_AND,
3085 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003086 "build_fcn": (
3087 build_binary_broadcast,
3088 TosaTensorGen.tgBroadcastFuzz,
3089 TosaTensorValuesGen.tvgDefault,
3090 None,
3091 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003092 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003093 "error_if_validators": (
3094 TosaErrorValidator.evRankMismatch,
3095 TosaErrorValidator.evWrongInputType,
3096 TosaErrorValidator.evWrongOutputType,
3097 TosaErrorValidator.evWrongInputList,
3098 TosaErrorValidator.evWrongOutputList,
3099 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003100 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003101 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003102 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003103 "logical_left_shift": {
3104 "op": Op.LOGICAL_LEFT_SHIFT,
3105 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003106 "build_fcn": (
3107 build_binary_broadcast,
3108 TosaTensorGen.tgBroadcastFuzz,
3109 TosaTensorValuesGen.tvgLogicalShift,
3110 None,
3111 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003112 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003113 "error_if_validators": (
3114 TosaErrorValidator.evRankMismatch,
3115 TosaErrorValidator.evWrongInputType,
3116 TosaErrorValidator.evWrongOutputType,
3117 TosaErrorValidator.evWrongInputList,
3118 TosaErrorValidator.evWrongOutputList,
3119 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003120 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003121 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003122 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003123 "logical_right_shift": {
3124 "op": Op.LOGICAL_RIGHT_SHIFT,
3125 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003126 "build_fcn": (
3127 build_binary_broadcast,
3128 TosaTensorGen.tgBroadcastFuzz,
3129 TosaTensorValuesGen.tvgLogicalShift,
3130 None,
3131 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003132 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003133 "error_if_validators": (
3134 TosaErrorValidator.evRankMismatch,
3135 TosaErrorValidator.evWrongInputType,
3136 TosaErrorValidator.evWrongOutputType,
3137 TosaErrorValidator.evWrongInputList,
3138 TosaErrorValidator.evWrongOutputList,
3139 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003140 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003141 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003142 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003143 "logical_or": {
3144 "op": Op.LOGICAL_OR,
3145 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003146 "build_fcn": (
3147 build_binary_broadcast,
3148 TosaTensorGen.tgBroadcastFuzz,
3149 TosaTensorValuesGen.tvgDefault,
3150 None,
3151 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003152 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003153 "error_if_validators": (
3154 TosaErrorValidator.evRankMismatch,
3155 TosaErrorValidator.evWrongInputType,
3156 TosaErrorValidator.evWrongOutputType,
3157 TosaErrorValidator.evWrongInputList,
3158 TosaErrorValidator.evWrongOutputList,
3159 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003160 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003161 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003162 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003163 "logical_xor": {
3164 "op": Op.LOGICAL_XOR,
3165 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003166 "build_fcn": (
3167 build_binary_broadcast,
3168 TosaTensorGen.tgBroadcastFuzz,
3169 TosaTensorValuesGen.tvgDefault,
3170 None,
3171 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003172 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003173 "error_if_validators": (
3174 TosaErrorValidator.evRankMismatch,
3175 TosaErrorValidator.evWrongInputType,
3176 TosaErrorValidator.evWrongOutputType,
3177 TosaErrorValidator.evWrongInputList,
3178 TosaErrorValidator.evWrongOutputList,
3179 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003180 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003181 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003182 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003183 "maximum": {
3184 "op": Op.MAXIMUM,
3185 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003186 "build_fcn": (
3187 build_binary_broadcast,
3188 TosaTensorGen.tgBroadcastFuzz,
3189 TosaTensorValuesGen.tvgDefault,
3190 None,
3191 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003192 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003193 "error_if_validators": (
3194 TosaErrorValidator.evRankMismatch,
3195 TosaErrorValidator.evWrongInputType,
3196 TosaErrorValidator.evWrongOutputType,
3197 TosaErrorValidator.evWrongInputList,
3198 TosaErrorValidator.evWrongOutputList,
3199 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003200 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003201 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003202 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003203 "minimum": {
3204 "op": Op.MINIMUM,
3205 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003206 "build_fcn": (
3207 build_binary_broadcast,
3208 TosaTensorGen.tgBroadcastFuzz,
3209 TosaTensorValuesGen.tvgDefault,
3210 None,
3211 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003212 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003213 "error_if_validators": (
3214 TosaErrorValidator.evRankMismatch,
3215 TosaErrorValidator.evWrongInputType,
3216 TosaErrorValidator.evWrongOutputType,
3217 TosaErrorValidator.evWrongInputList,
3218 TosaErrorValidator.evWrongOutputList,
3219 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003220 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003221 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003222 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003223 "mul": {
3224 "op": Op.MUL,
3225 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003226 "build_fcn": (
3227 build_mul,
3228 TosaTensorGen.tgBroadcastFuzz,
3229 TosaTensorValuesGen.tvgMul,
3230 TosaArgGen.agMul,
3231 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003232 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003233 "error_if_validators": (
3234 TosaErrorValidator.evWrongInputType,
3235 TosaErrorValidator.evWrongOutputType,
3236 TosaErrorValidator.evWrongInputList,
3237 TosaErrorValidator.evWrongOutputList,
3238 TosaErrorValidator.evRankMismatch,
3239 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003240 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003241 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003242 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003243 "pow": {
3244 "op": Op.POW,
3245 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003246 "build_fcn": (
3247 build_binary_broadcast,
3248 TosaTensorGen.tgBroadcastFuzz,
3249 TosaTensorValuesGen.tvgDefault,
3250 None,
3251 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003252 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003253 "error_if_validators": (
3254 TosaErrorValidator.evRankMismatch,
3255 TosaErrorValidator.evWrongInputType,
3256 TosaErrorValidator.evWrongOutputType,
3257 TosaErrorValidator.evWrongInputList,
3258 TosaErrorValidator.evWrongOutputList,
3259 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003260 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003261 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003262 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003263 "sub": {
3264 "op": Op.SUB,
3265 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003266 "build_fcn": (
3267 build_binary_broadcast,
3268 TosaTensorGen.tgBroadcastFuzz,
3269 TosaTensorValuesGen.tvgAddSub,
3270 None,
3271 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003272 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003273 "error_if_validators": (
3274 TosaErrorValidator.evRankMismatch,
3275 TosaErrorValidator.evWrongInputType,
3276 TosaErrorValidator.evWrongOutputType,
3277 TosaErrorValidator.evWrongInputList,
3278 TosaErrorValidator.evWrongOutputList,
3279 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003280 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003281 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003282 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003283 "table": {
3284 "op": Op.TABLE,
3285 # Use the automatic generation functions to create the input array
3286 # but create the table tensor in the build function, as it may be
3287 # a different type from the input
3288 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003289 "build_fcn": (
3290 build_table,
3291 TosaTensorGen.tgBasic,
3292 TosaTensorValuesGen.tvgDefault,
3293 TosaArgGen.agTable,
3294 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003295 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003296 "error_if_validators": (
3297 TosaErrorValidator.evWrongInputType,
3298 TosaErrorValidator.evWrongOutputType,
3299 TosaErrorValidator.evWrongInputList,
3300 TosaErrorValidator.evWrongOutputList,
3301 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003302 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003303 # Elementwise Unary operators
3304 "abs": {
3305 "op": Op.ABS,
3306 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003307 "build_fcn": (
3308 build_unary,
3309 TosaTensorGen.tgBasic,
3310 TosaTensorValuesGen.tvgDefault,
3311 None,
3312 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003313 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003314 "error_if_validators": (
3315 TosaErrorValidator.evWrongInputType,
3316 TosaErrorValidator.evWrongOutputType,
3317 TosaErrorValidator.evWrongInputList,
3318 TosaErrorValidator.evWrongOutputList,
3319 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003320 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003321 "bitwise_not": {
3322 "op": Op.BITWISE_NOT,
3323 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003324 "build_fcn": (
3325 build_unary,
3326 TosaTensorGen.tgBasic,
3327 TosaTensorValuesGen.tvgDefault,
3328 None,
3329 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003330 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003331 "error_if_validators": (
3332 TosaErrorValidator.evWrongInputType,
3333 TosaErrorValidator.evWrongOutputType,
3334 TosaErrorValidator.evWrongInputList,
3335 TosaErrorValidator.evWrongOutputList,
3336 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003337 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003338 "ceil": {
3339 "op": Op.CEIL,
3340 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003341 "build_fcn": (
3342 build_unary,
3343 TosaTensorGen.tgBasic,
3344 TosaTensorValuesGen.tvgDefault,
3345 None,
3346 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003347 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003348 "error_if_validators": (
3349 TosaErrorValidator.evWrongInputType,
3350 TosaErrorValidator.evWrongOutputType,
3351 TosaErrorValidator.evWrongInputList,
3352 TosaErrorValidator.evWrongOutputList,
3353 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003354 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003355 "clz": {
3356 "op": Op.CLZ,
3357 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003358 "build_fcn": (
3359 build_unary,
3360 TosaTensorGen.tgBasic,
3361 TosaTensorValuesGen.tvgDefault,
3362 None,
3363 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003364 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003365 "error_if_validators": (
3366 TosaErrorValidator.evWrongInputType,
3367 TosaErrorValidator.evWrongOutputType,
3368 TosaErrorValidator.evWrongInputList,
3369 TosaErrorValidator.evWrongOutputList,
3370 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003371 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003372 "exp": {
3373 "op": Op.EXP,
3374 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003375 "build_fcn": (
3376 build_unary,
3377 TosaTensorGen.tgBasic,
3378 TosaTensorValuesGen.tvgDefault,
3379 None,
3380 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003381 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003382 "error_if_validators": (
3383 TosaErrorValidator.evWrongInputType,
3384 TosaErrorValidator.evWrongOutputType,
3385 TosaErrorValidator.evWrongInputList,
3386 TosaErrorValidator.evWrongOutputList,
3387 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003388 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003389 "floor": {
3390 "op": Op.FLOOR,
3391 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003392 "build_fcn": (
3393 build_unary,
3394 TosaTensorGen.tgBasic,
3395 TosaTensorValuesGen.tvgDefault,
3396 None,
3397 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003398 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003399 "error_if_validators": (
3400 TosaErrorValidator.evWrongInputType,
3401 TosaErrorValidator.evWrongOutputType,
3402 TosaErrorValidator.evWrongInputList,
3403 TosaErrorValidator.evWrongOutputList,
3404 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003405 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003406 "log": {
3407 "op": Op.LOG,
3408 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003409 "build_fcn": (
3410 build_unary,
3411 TosaTensorGen.tgBasic,
3412 TosaTensorValuesGen.tvgDefault,
3413 None,
3414 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003415 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003416 "error_if_validators": (
3417 TosaErrorValidator.evWrongInputType,
3418 TosaErrorValidator.evWrongOutputType,
3419 TosaErrorValidator.evWrongInputList,
3420 TosaErrorValidator.evWrongOutputList,
3421 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003422 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003423 "logical_not": {
3424 "op": Op.LOGICAL_NOT,
3425 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003426 "build_fcn": (
3427 build_unary,
3428 TosaTensorGen.tgBasic,
3429 TosaTensorValuesGen.tvgDefault,
3430 None,
3431 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003432 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003433 "error_if_validators": (
3434 TosaErrorValidator.evWrongInputType,
3435 TosaErrorValidator.evWrongOutputType,
3436 TosaErrorValidator.evWrongInputList,
3437 TosaErrorValidator.evWrongOutputList,
3438 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003439 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003440 "negate": {
3441 "op": Op.NEGATE,
3442 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003443 "build_fcn": (
3444 build_unary,
3445 TosaTensorGen.tgBasic,
3446 TosaTensorValuesGen.tvgNegate,
3447 None,
3448 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003449 "qgen": TosaQuantGen.qgUnary,
3450 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003451 "error_if_validators": (
3452 TosaErrorValidator.evInputZeroPointNotZero,
3453 TosaErrorValidator.evOutputZeroPointNotZero,
3454 TosaErrorValidator.evWrongInputType,
3455 TosaErrorValidator.evWrongOutputType,
3456 TosaErrorValidator.evWrongInputList,
3457 TosaErrorValidator.evWrongOutputList,
3458 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003459 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003460 "reciprocal": {
3461 "op": Op.RECIPROCAL,
3462 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003463 "build_fcn": (
3464 build_unary,
3465 TosaTensorGen.tgBasic,
3466 TosaTensorValuesGen.tvgDefault,
3467 None,
3468 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003469 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003470 "error_if_validators": (
3471 TosaErrorValidator.evWrongInputType,
3472 TosaErrorValidator.evWrongOutputType,
3473 TosaErrorValidator.evWrongInputList,
3474 TosaErrorValidator.evWrongOutputList,
3475 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003476 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003477 "rsqrt": {
3478 "op": Op.RSQRT,
3479 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003480 "build_fcn": (
3481 build_unary,
3482 TosaTensorGen.tgBasic,
3483 TosaTensorValuesGen.tvgDefault,
3484 None,
3485 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003486 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003487 "error_if_validators": (
3488 TosaErrorValidator.evWrongInputType,
3489 TosaErrorValidator.evWrongOutputType,
3490 TosaErrorValidator.evWrongInputList,
3491 TosaErrorValidator.evWrongOutputList,
3492 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003493 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003494 # Elementwise Ternary operators
3495 "select": {
3496 "op": Op.SELECT,
3497 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003498 "build_fcn": (
3499 build_select,
3500 TosaTensorGen.tgBroadcastFuzz,
3501 TosaTensorValuesGen.tvgSelect,
3502 None,
3503 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003504 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003505 "error_if_validators": (
3506 TosaErrorValidator.evRankMismatch,
3507 TosaErrorValidator.evWrongInputType,
3508 TosaErrorValidator.evWrongOutputType,
3509 TosaErrorValidator.evWrongInputList,
3510 TosaErrorValidator.evWrongOutputList,
3511 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003512 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003513 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003514 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003515 # Comparison operators
3516 "equal": {
3517 "op": Op.EQUAL,
3518 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003519 "build_fcn": (
3520 build_comparison,
3521 TosaTensorGen.tgBroadcastFuzz,
3522 TosaTensorValuesGen.tvgEqual,
3523 None,
3524 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003525 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003526 "error_if_validators": (
3527 TosaErrorValidator.evRankMismatch,
3528 TosaErrorValidator.evWrongInputType,
3529 TosaErrorValidator.evWrongOutputType,
3530 TosaErrorValidator.evWrongInputList,
3531 TosaErrorValidator.evWrongOutputList,
3532 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003533 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003534 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003535 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003536 "greater_equal": {
3537 "op": Op.GREATER_EQUAL,
3538 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003539 "build_fcn": (
3540 build_comparison,
3541 TosaTensorGen.tgBroadcastFuzz,
3542 TosaTensorValuesGen.tvgDefault,
3543 None,
3544 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003545 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003546 "error_if_validators": (
3547 TosaErrorValidator.evRankMismatch,
3548 TosaErrorValidator.evWrongInputType,
3549 TosaErrorValidator.evWrongOutputType,
3550 TosaErrorValidator.evWrongInputList,
3551 TosaErrorValidator.evWrongOutputList,
3552 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003553 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003554 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003555 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003556 "greater": {
3557 "op": Op.GREATER,
3558 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003559 "build_fcn": (
3560 build_comparison,
3561 TosaTensorGen.tgBroadcastFuzz,
3562 TosaTensorValuesGen.tvgDefault,
3563 None,
3564 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003565 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003566 "error_if_validators": (
3567 TosaErrorValidator.evRankMismatch,
3568 TosaErrorValidator.evWrongInputType,
3569 TosaErrorValidator.evWrongOutputType,
3570 TosaErrorValidator.evWrongInputList,
3571 TosaErrorValidator.evWrongOutputList,
3572 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003573 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003574 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003575 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003576 # Reduction operators
3577 "reduce_all": {
3578 "op": Op.REDUCE_ALL,
3579 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003580 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003581 "build_fcn": (
3582 build_reduce,
3583 TosaTensorGen.tgBasic,
3584 TosaTensorValuesGen.tvgDefault,
3585 TosaArgGen.agAxis,
3586 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003587 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003588 "error_if_validators": (
3589 TosaErrorValidator.evAxisLargerRank,
3590 TosaErrorValidator.evAxisSmallerZero,
3591 TosaErrorValidator.evShapeOfAxisNotOne,
3592 TosaErrorValidator.evWrongInputType,
3593 TosaErrorValidator.evWrongOutputType,
3594 TosaErrorValidator.evWrongRank,
3595 TosaErrorValidator.evWrongInputList,
3596 TosaErrorValidator.evWrongOutputList,
3597 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003598 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003599 "reduce_any": {
3600 "op": Op.REDUCE_ANY,
3601 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003602 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003603 "build_fcn": (
3604 build_reduce,
3605 TosaTensorGen.tgBasic,
3606 TosaTensorValuesGen.tvgDefault,
3607 TosaArgGen.agAxis,
3608 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003609 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003610 "error_if_validators": (
3611 TosaErrorValidator.evAxisLargerRank,
3612 TosaErrorValidator.evAxisSmallerZero,
3613 TosaErrorValidator.evShapeOfAxisNotOne,
3614 TosaErrorValidator.evWrongInputType,
3615 TosaErrorValidator.evWrongOutputType,
3616 TosaErrorValidator.evWrongRank,
3617 TosaErrorValidator.evWrongInputList,
3618 TosaErrorValidator.evWrongOutputList,
3619 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003620 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003621 "reduce_max": {
3622 "op": Op.REDUCE_MAX,
3623 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003624 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003625 "build_fcn": (
3626 build_reduce,
3627 TosaTensorGen.tgBasic,
3628 TosaTensorValuesGen.tvgDefault,
3629 TosaArgGen.agAxis,
3630 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003631 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003632 "error_if_validators": (
3633 TosaErrorValidator.evAxisLargerRank,
3634 TosaErrorValidator.evAxisSmallerZero,
3635 TosaErrorValidator.evShapeOfAxisNotOne,
3636 TosaErrorValidator.evWrongInputType,
3637 TosaErrorValidator.evWrongOutputType,
3638 TosaErrorValidator.evWrongRank,
3639 TosaErrorValidator.evWrongInputList,
3640 TosaErrorValidator.evWrongOutputList,
3641 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003642 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003643 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003644 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003645 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003646 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003647 "build_fcn": (
3648 build_reduce,
3649 TosaTensorGen.tgBasic,
3650 TosaTensorValuesGen.tvgDefault,
3651 TosaArgGen.agAxis,
3652 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003653 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003654 "error_if_validators": (
3655 TosaErrorValidator.evAxisLargerRank,
3656 TosaErrorValidator.evAxisSmallerZero,
3657 TosaErrorValidator.evShapeOfAxisNotOne,
3658 TosaErrorValidator.evWrongInputType,
3659 TosaErrorValidator.evWrongOutputType,
3660 TosaErrorValidator.evWrongRank,
3661 TosaErrorValidator.evWrongInputList,
3662 TosaErrorValidator.evWrongOutputList,
3663 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003664 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003665 "reduce_product": {
3666 "op": Op.REDUCE_PRODUCT,
3667 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003668 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003669 "build_fcn": (
3670 build_reduce,
3671 TosaTensorGen.tgBasic,
3672 TosaTensorValuesGen.tvgDefault,
3673 TosaArgGen.agAxis,
3674 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003675 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003676 "error_if_validators": (
3677 TosaErrorValidator.evAxisLargerRank,
3678 TosaErrorValidator.evAxisSmallerZero,
3679 TosaErrorValidator.evShapeOfAxisNotOne,
3680 TosaErrorValidator.evWrongInputType,
3681 TosaErrorValidator.evWrongOutputType,
3682 TosaErrorValidator.evWrongRank,
3683 TosaErrorValidator.evWrongInputList,
3684 TosaErrorValidator.evWrongOutputList,
3685 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003686 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003687 "reduce_sum": {
3688 "op": Op.REDUCE_SUM,
3689 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003690 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003691 "build_fcn": (
3692 build_reduce,
3693 TosaTensorGen.tgBasic,
3694 TosaTensorValuesGen.tvgReduceSum,
3695 TosaArgGen.agAxis,
3696 ),
James Ward24dbc422022-10-19 12:20:31 +01003697 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003698 "error_if_validators": (
3699 TosaErrorValidator.evAxisLargerRank,
3700 TosaErrorValidator.evAxisSmallerZero,
3701 TosaErrorValidator.evShapeOfAxisNotOne,
3702 TosaErrorValidator.evWrongInputType,
3703 TosaErrorValidator.evWrongOutputType,
3704 TosaErrorValidator.evWrongRank,
3705 TosaErrorValidator.evWrongInputList,
3706 TosaErrorValidator.evWrongOutputList,
3707 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003708 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003709 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003710 "concat": {
3711 "op": Op.CONCAT,
3712 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003713 "build_fcn": (
3714 build_concat,
3715 TosaTensorGen.tgConcat,
3716 TosaTensorValuesGen.tvgConcat,
3717 TosaArgGen.agAxis,
3718 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003719 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003720 "error_if_validators": (
3721 TosaErrorValidator.evAxisLargerRank,
3722 TosaErrorValidator.evAxisSmallerZero,
3723 TosaErrorValidator.evConcatInputRankMismatch,
3724 TosaErrorValidator.evConcatShapeSumMismatch,
3725 TosaErrorValidator.evConcatInputDimMismatch,
3726 TosaErrorValidator.evWrongInputType,
3727 TosaErrorValidator.evWrongOutputType,
3728 TosaErrorValidator.evWrongOutputList,
3729 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003730 },
3731 "pad": {
3732 "op": Op.PAD,
3733 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003734 "build_fcn": (
3735 build_pad,
3736 TosaTensorGen.tgBasic,
3737 TosaTensorValuesGen.tvgDefault,
3738 TosaArgGen.agPad,
3739 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003740 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003741 "error_if_validators": (
3742 TosaErrorValidator.evWrongInputType,
3743 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003744 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003745 TosaErrorValidator.evWrongOutputType,
3746 TosaErrorValidator.evWrongInputList,
3747 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003748 TosaErrorValidator.evRankMismatch,
3749 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003750 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003751 },
3752 "reshape": {
3753 "op": Op.RESHAPE,
3754 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003755 "build_fcn": (
3756 build_reshape,
3757 TosaTensorGen.tgBasic,
3758 TosaTensorValuesGen.tvgDefault,
3759 TosaArgGen.agReshape,
3760 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003761 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003762 "error_if_validators": (
3763 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3764 TosaErrorValidator.evWrongInputType,
3765 TosaErrorValidator.evWrongOutputType,
3766 TosaErrorValidator.evWrongInputList,
3767 TosaErrorValidator.evWrongOutputList,
Jerry Ge264f7fa2023-04-21 22:49:57 +00003768 TosaErrorValidator.evReshapeOutputSizeMultiInference,
3769 TosaErrorValidator.evReshapeOutputSizeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003770 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003771 },
3772 "reverse": {
3773 "op": Op.REVERSE,
3774 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003775 "build_fcn": (
3776 build_reverse,
3777 TosaTensorGen.tgBasic,
3778 TosaTensorValuesGen.tvgDefault,
3779 TosaArgGen.agAxis,
3780 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003781 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003782 "error_if_validators": (
3783 TosaErrorValidator.evAxisSmallerZero,
3784 TosaErrorValidator.evAxisLargerRank,
3785 TosaErrorValidator.evWrongInputType,
3786 TosaErrorValidator.evWrongOutputType,
3787 TosaErrorValidator.evWrongInputList,
3788 TosaErrorValidator.evWrongOutputList,
3789 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003790 },
3791 "slice": {
3792 "op": Op.SLICE,
3793 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003794 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003795 "build_fcn": (
3796 build_slice,
3797 TosaTensorGen.tgBasic,
3798 TosaTensorValuesGen.tvgDefault,
3799 TosaArgGen.agSlice,
3800 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003801 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003802 "error_if_validators": (
3803 TosaErrorValidator.evStartSmallerZero,
3804 TosaErrorValidator.evSizeSmallerEqualZero,
3805 TosaErrorValidator.evStartSizeOutsideBounds,
3806 TosaErrorValidator.evSizeOutputShapeMismatch,
3807 TosaErrorValidator.evInputSizeStartLengthMismatch,
3808 TosaErrorValidator.evWrongRank,
3809 TosaErrorValidator.evWrongInputType,
3810 TosaErrorValidator.evWrongOutputType,
3811 TosaErrorValidator.evWrongInputList,
3812 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003813 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003814 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003815 },
3816 "tile": {
3817 "op": Op.TILE,
3818 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003819 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003820 "build_fcn": (
3821 build_tile,
3822 TosaTensorGen.tgBasic,
3823 TosaTensorValuesGen.tvgDefault,
3824 TosaArgGen.agTile,
3825 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003826 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003827 "error_if_validators": (
3828 TosaErrorValidator.evWrongInputType,
3829 TosaErrorValidator.evWrongOutputType,
3830 TosaErrorValidator.evWrongInputList,
3831 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003832 TosaErrorValidator.evRankMismatch,
3833 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003834 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003835 },
3836 "transpose": {
3837 "op": Op.TRANSPOSE,
3838 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003839 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003840 "build_fcn": (
3841 build_transpose,
3842 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003843 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003844 TosaArgGen.agTranspose,
3845 ),
3846 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003847 "error_if_validators": (
3848 TosaErrorValidator.evIndexOutsideBounds,
3849 TosaErrorValidator.evIndexUsedTwice,
3850 TosaErrorValidator.evWrongInputType,
3851 TosaErrorValidator.evWrongOutputType,
3852 TosaErrorValidator.evWrongInputList,
3853 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003854 TosaErrorValidator.evWrongRank,
3855 TosaErrorValidator.evRankMismatch,
3856 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003857 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003858 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003859 # Data nodes
3860 "const": {
3861 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003862 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003863 "build_fcn": (
3864 build_const,
3865 TosaTensorGen.tgBasic,
3866 TosaTensorValuesGen.tvgDefault,
3867 None,
3868 ),
Luke Hutton65872422023-02-20 10:33:04 +00003869 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08003870 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003871 "identity": {
3872 "op": Op.IDENTITY,
3873 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003874 "build_fcn": (
3875 build_unary,
3876 TosaTensorGen.tgBasic,
3877 TosaTensorValuesGen.tvgDefault,
3878 None,
3879 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003880 "types": TYPE_FIB,
3881 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003882 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003883 "gather": {
3884 "op": Op.GATHER,
3885 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3886 "operands": (1, 0),
3887 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003888 "build_fcn": (
3889 build_gather,
3890 TosaTensorGen.tgBasic,
3891 TosaTensorValuesGen.tvgDefault,
3892 None,
3893 ),
James Ward24dbc422022-10-19 12:20:31 +01003894 "types": (
3895 DType.INT8,
3896 DType.INT16,
3897 DType.INT32,
3898 DType.FP16,
3899 DType.BF16,
3900 DType.FP32,
3901 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003902 "error_if_validators": (
3903 TosaErrorValidator.evWrongInputType,
3904 TosaErrorValidator.evWrongOutputType,
3905 TosaErrorValidator.evWrongInputList,
3906 TosaErrorValidator.evWrongOutputList,
3907 TosaErrorValidator.evWrongRank,
3908 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003909 },
3910 "scatter": {
3911 "op": Op.SCATTER,
3912 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003913 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003914 "operands": (2, 0),
3915 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003916 "build_fcn": (
3917 build_scatter,
3918 TosaTensorGen.tgScatter,
3919 TosaTensorValuesGen.tvgDefault,
3920 None,
3921 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003922 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003923 "error_if_validators": (
3924 TosaErrorValidator.evWrongInputType,
3925 TosaErrorValidator.evWrongOutputType,
3926 TosaErrorValidator.evWrongInputList,
3927 TosaErrorValidator.evWrongOutputList,
3928 TosaErrorValidator.evWrongRank,
3929 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003930 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003931 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003932 "resize": {
3933 "op": Op.RESIZE,
3934 "operands": (1, 0),
3935 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003936 "build_fcn": (
3937 build_resize,
3938 TosaTensorGen.tgNHWC,
3939 TosaTensorValuesGen.tvgDefault,
3940 TosaArgGen.agResize,
3941 ),
James Ward24dbc422022-10-19 12:20:31 +01003942 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003943 "invalid_test_validators": (
3944 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003945 ),
3946 "error_if_validators": (
3947 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003948 TosaErrorValidator.evScaleSmallerEqualZero,
3949 TosaErrorValidator.evScaleNLargerMax,
3950 TosaErrorValidator.evScaleDLargerMax,
3951 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003952 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003953 TosaErrorValidator.evBorderSmallerMin,
3954 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003955 TosaErrorValidator.evWrongInputType,
3956 TosaErrorValidator.evWrongOutputType,
3957 TosaErrorValidator.evWrongRank,
3958 TosaErrorValidator.evWrongInputList,
3959 TosaErrorValidator.evWrongOutputList,
3960 TosaErrorValidator.evBatchMismatch,
3961 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003962 TosaErrorValidator.evResizeOutputShapeMismatch,
3963 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003964 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003965 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003966 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003967 "cast": {
3968 "op": Op.CAST,
3969 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003970 "build_fcn": (
3971 build_cast,
3972 TosaTensorGen.tgBasic,
3973 TosaTensorValuesGen.tvgDefault,
3974 TosaArgGen.agCast,
3975 ),
James Ward8b390432022-08-12 20:48:56 +01003976 "types": (
3977 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01003978 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003979 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01003980 DType.INT8,
3981 DType.INT16,
3982 DType.INT32,
3983 DType.BOOL,
3984 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003985 "error_if_validators": (
3986 TosaErrorValidator.evWrongInputType,
3987 TosaErrorValidator.evWrongOutputType,
3988 TosaErrorValidator.evWrongInputList,
3989 TosaErrorValidator.evWrongOutputList,
3990 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003991 },
3992 "rescale": {
3993 "op": Op.RESCALE,
3994 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003995 "build_fcn": (
3996 build_rescale,
3997 TosaTensorGen.tgBasic,
3998 TosaTensorValuesGen.tvgDefault,
3999 TosaArgGen.agRescale,
4000 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004001 "types": [
4002 DType.UINT8,
4003 DType.INT8,
4004 DType.INT16,
4005 DType.INT32,
4006 DType.INT48,
4007 DType.UINT16,
4008 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004009 "error_if_validators": (
4010 TosaErrorValidator.evInputZeroPointNotZero,
4011 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004012 TosaErrorValidator.evU16InputZeroPointNotValid,
4013 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004014 TosaErrorValidator.evScaleTrue,
4015 TosaErrorValidator.evScaleNotTrue,
4016 TosaErrorValidator.evWrongInputType,
4017 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004018 TosaErrorValidator.evWrongInputList,
4019 TosaErrorValidator.evWrongOutputList,
4020 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004021 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004022 # Custom
4023 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004024 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004025 # Two variants of cond_if, one that generates one of two constant tensors (no
4026 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4027 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004028 "cond_if_const": {
4029 "op": Op.COND_IF,
4030 "operands": (0, 2),
4031 "build_fcn": (
4032 build_cond_if_const,
4033 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004034 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004035 TosaArgGen.agCondIf,
4036 ),
4037 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004038 "error_if_validators": (
4039 TosaErrorValidator.evOutputListThenGraphMismatch,
4040 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004041 TosaErrorValidator.evCondIfCondNotMatchingBool,
4042 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004043 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004044 },
4045 "cond_if_binary": {
4046 "op": Op.COND_IF,
4047 "operands": (2, 0),
4048 "build_fcn": (
4049 build_cond_if_binary,
4050 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004051 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004052 TosaArgGen.agCondIf,
4053 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004054 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004055 "error_if_validators": (
4056 TosaErrorValidator.evInputListThenGraphMismatch,
4057 TosaErrorValidator.evInputListElseGraphMismatch,
4058 TosaErrorValidator.evOutputListThenGraphMismatch,
4059 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004060 TosaErrorValidator.evCondIfCondNotMatchingBool,
4061 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004062 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004063 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004064 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004065 "while_loop": {
4066 "op": Op.WHILE_LOOP,
4067 "operands": (0, 1),
4068 "build_fcn": (
4069 build_while_loop,
4070 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004071 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004072 TosaArgGen.agWhileLoop,
4073 ),
4074 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004075 "error_if_validators": (
4076 TosaErrorValidator.evInputListOutputListMismatch,
4077 TosaErrorValidator.evInputListCondGraphMismatch,
4078 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4079 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4080 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004081 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004082 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004083 },
Luke Hutton57287132023-02-06 14:54:18 +00004084 "fft2d": {
4085 "op": Op.FFT2D,
4086 "operands": (2, 0),
4087 "rank": (3, 3),
4088 "build_fcn": (
4089 build_fft2d,
4090 TosaTensorGen.tgFFT2d,
4091 TosaTensorValuesGen.tvgDefault,
4092 TosaArgGen.agFFT2d,
4093 ),
4094 "types": [DType.FP32],
4095 "error_if_validators": (
4096 TosaErrorValidator.evWrongInputType,
4097 TosaErrorValidator.evWrongOutputType,
4098 TosaErrorValidator.evWrongInputList,
4099 TosaErrorValidator.evWrongOutputList,
4100 TosaErrorValidator.evWrongRank,
4101 TosaErrorValidator.evBatchMismatch,
4102 TosaErrorValidator.evKernelNotPowerOfTwo,
4103 TosaErrorValidator.evFFTInputShapeMismatch,
4104 TosaErrorValidator.evFFTOutputShapeMismatch,
4105 ),
4106 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004107 "rfft2d": {
4108 "op": Op.RFFT2D,
4109 "operands": (1, 0),
4110 "rank": (3, 3),
4111 "build_fcn": (
4112 build_rfft2d,
4113 TosaTensorGen.tgRFFT2d,
4114 TosaTensorValuesGen.tvgDefault,
4115 TosaArgGen.agNone,
4116 ),
4117 "types": [DType.FP32],
4118 "error_if_validators": (
4119 TosaErrorValidator.evWrongInputType,
4120 TosaErrorValidator.evWrongOutputType,
4121 TosaErrorValidator.evWrongInputList,
4122 TosaErrorValidator.evWrongOutputList,
4123 TosaErrorValidator.evWrongRank,
4124 TosaErrorValidator.evBatchMismatch,
4125 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004126 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004127 ),
4128 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004129 }
4130
Kevin Cheng550ccc52021-03-03 11:21:43 -08004131
Eric Kunzee5e26762020-10-13 16:11:07 -07004132class OutputShaper:
4133 # Methods in this class compute the expected output shape and datatype
4134 # for common classes of operations
def __init__(self):
    """Nothing to set up: this constructor stores no instance state."""
4137
4138 # These methods return arguments that can be used for
4139 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor of a broadcasting elementwise binary op.

    Args:
        ser: serializer; its ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator providing ``integers`` and ``choice``.
        a, b: input tensors (their ``shape`` and ``dtype`` attributes are read).
        error_name: optional ErrorIf case to inject into the output.

    With no error requested, every dimension where ``a`` is 1 takes ``b``'s
    size; if any error is requested, ``a``'s dimensions are copied as-is.
    ErrorIf.DimensionMismatch grows one randomly chosen dimension by one and
    ErrorIf.WrongOutputType picks an output dtype different from ``a.dtype``.

    Returns:
        The result of ``ser.addOutput``.
    """
    # Ranks only have to agree when a rank mismatch is not the error under test.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Index into b rather than zip with it: under RankMismatch the ranks may
    # differ and the output must still follow a's rank.
    apply_broadcast = error_name is None
    shape = [
        b.shape[idx] if (dim == 1 and apply_broadcast) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # Drawn unconditionally, so the number of RNG calls here does not depend
    # on error_name (presumably to keep generated tests reproducible).
    bump_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[bump_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]
    # rng.choice selects by position, so keep the set-based construction that
    # determines this list's order.
    wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004173
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor of a binary op whose inputs must match exactly.

    ``a`` and ``b`` must agree in rank, dtype and every dimension (checked
    with asserts). The output takes ``a``'s dimensions, as a list, and its
    dtype.

    Returns:
        The result of ``ser.addOutput``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype
    for a_dim, b_dim in zip(a.shape, b.shape):
        assert a_dim == b_dim
    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004185
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor of an elementwise unary op.

    Args:
        ser: serializer; its ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator providing ``choice``.
        a: input tensor (its ``shape`` and ``dtype`` attributes are read).
        error_name: optional ErrorIf case to inject into the output.

    The output copies ``a.shape``. Its dtype is ``a.dtype`` unless
    ``error_name`` is ErrorIf.WrongOutputType, in which case a different
    dtype is picked with ``rng.choice``.

    Returns:
        The result of ``ser.addOutput``.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # rng.choice selects by position, so keep the set-based construction
        # that determines this list's order.
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004204
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor of a SELECT(cond, a, b) op.

    Args:
        ser: serializer; its ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator providing ``integers`` and ``choice``.
        cond, a, b: input tensors (``shape``/``dtype`` attributes are read).
        error_name: optional ErrorIf case to inject into the output.

    With no error requested, each dimension where ``cond`` is 1 becomes the
    largest of the three inputs' sizes; every other dimension (and every
    dimension when an error is requested) copies ``cond``.
    ErrorIf.DimensionMismatch grows one randomly chosen dimension by one and
    ErrorIf.WrongOutputType picks an output dtype different from ``a.dtype``.

    Returns:
        The result of ``ser.addOutput``.
    """
    # Ranks only have to agree when a rank mismatch is not the error under test.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    apply_broadcast = error_name is None
    shape = [
        max(cond_dim, a.shape[idx], b.shape[idx])
        if (cond_dim == 1 and apply_broadcast)
        else cond_dim
        for idx, cond_dim in enumerate(cond.shape)
    ]

    # Drawn unconditionally, so the number of RNG calls here does not depend
    # on error_name (presumably to keep generated tests reproducible).
    bump_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[bump_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # rng.choice selects by position, so keep the set-based construction that
    # determines this list's order.
    wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004238
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Create the BOOL output tensor for an elementwise comparison.

    Each output dimension takes ``b``'s size where ``a``'s size is 1 and
    ``b`` has that dimension; otherwise it keeps ``a``'s size.

    ERROR_IF variants:
    * ``DimensionMismatch`` grows one random dimension by 1.
    * ``WrongOutputType`` picks a random non-BOOL dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast size-1 dimensions of ``a`` against ``b``
    out_shape = [
        b.shape[idx] if (dim == 1 and idx < len(b.shape)) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # rng is consumed here whatever error_name is
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.BOOL)

    non_bool_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput(out_shape, rng.choice(non_bool_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004272
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for a reduction along ``axis``.

    Normally ``axis`` collapses to size 1 and the dtype of ``a`` is kept.

    ERROR_IF variants:
    * ``AxisSmallerZero`` and ``AxisLargerRank`` leave the shape untouched.
    * ``ShapeOfAxisNotOne`` keeps the input size along ``axis``; if that
      size is 1, it is replaced by a random size in [2, 10).
    * ``WrongOutputType`` picks a dtype other than ``a.dtype``.
    """
    out_shape = a.shape.copy()
    if error_name == ErrorIf.ShapeOfAxisNotOne:
        if out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape[axis] = 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004301
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Create the INT32 output tensor for ARGMAX along ``axis``.

    ``axis`` is removed from the shape, except for the axis-range ERROR_IF
    cases.

    ERROR_IF variants:
    * ``ArgmaxOutputRankMismatch`` either drops the leading dimension (only
      when more than one dimension remains) or appends a 1.
    * ``ArgmaxOutputShapeMismatch`` grows every dimension by a random amount
      in [1, 10).
    * ``WrongOutputType`` picks a non-INT32 dtype.
    """
    out_shape = a.shape.copy()
    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_first = rng.choice([True, False])
        if drop_first and len(out_shape) > 1:
            out_shape.pop(0)
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    out_dtype = DType.INT32
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([DType.INT32])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004335
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for CONV2D.

    Layouts: IFM is NHWC, the filter is OHWI and the OFM is NHWC. Each
    spatial output size is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``.

    Output dtype and ERROR_IF variants:
    * The dtype is normally ``accum_dtype``.
    * ``WrongInputType`` uses INT32.
    * ``ConvOutputShapeMismatch`` grows H and/or W by a multiple of the
      stride.
    * ``WrongOutputType`` draws from ``usableDTypes`` with the valid
      dtype(s) excluded.
    """
    # Spatial output sizes: index 0 is H (IFM axis 1), index 1 is W (axis 2)
    out_hw = []
    for i in range(2):
        out_hw.append(
            (
                ifm.shape[1 + i]
                - 1
                + padding[2 * i]
                + padding[2 * i + 1]
                - (filter.shape[1 + i] - 1) * dilations[i]
            )
            // strides[i]
            + 1
        )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1: grow H only, 2: grow W only, 3: grow both
        which = rng.choice(steps)
        for i in range(2):
            if which in (i + 1, 3):
                # Grow by a multiple of the stride so the non-integer
                # output-size error case is not hit instead
                out_hw[i] = out_hw[i] + rng.choice(steps) * strides[i]

    ofm_shape = [ifm.shape[0], out_hw[0], out_hw[1], filter.shape[0]]

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            # FP16 input excludes both FP16 and FP32 (presumably both are
            # acceptable outputs for it)
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [accum_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004387
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for CONV3D.

    Layouts: IFM is NDHWC, the filter is ODHWI and the OFM is NDHWC. Each
    spatial output size is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``.

    Output dtype and ERROR_IF variants:
    * The dtype is normally ``accum_dtype``.
    * ``WrongInputType`` uses INT32.
    * ``ConvOutputShapeMismatch`` grows D, H and/or W by a multiple of the
      stride.
    * ``WrongOutputType`` draws from ``usableDTypes`` with the valid
      dtype(s) excluded.
    """
    # Spatial output sizes: indices 0, 1, 2 are D, H, W (IFM axes 1, 2, 3)
    out_dhw = []
    for i in range(3):
        out_dhw.append(
            (
                ifm.shape[1 + i]
                - 1
                + padding[2 * i]
                + padding[2 * i + 1]
                - (filter.shape[1 + i] - 1) * dilations[i]
            )
            // strides[i]
            + 1
        )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        # 1: grow D, 2: grow H, 3: grow W, 4: grow all three
        which = rng.choice(steps)
        for i in range(3):
            if which in (i + 1, 4):
                # Grow by a multiple of the stride so the non-integer
                # output-size error case is not hit instead
                out_dhw[i] = out_dhw[i] + rng.choice(steps) * strides[i]

    ofm_shape = [ifm.shape[0], out_dhw[0], out_dhw[1], out_dhw[2], filter.shape[0]]

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            # FP16 input excludes both FP16 and FP32 (presumably both are
            # acceptable outputs for it)
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [accum_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    return ser.addOutput(ofm_shape, out_dtype)
4449
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for DEPTHWISE_CONV2D.

    Layouts: IFM is NHWC, the filter is HWCM and the OFM is NHW(C*M), so
    the output channel count is ``filter.shape[2] * filter.shape[3]``.

    Output dtype and ERROR_IF variants:
    * The dtype is normally ``accum_dtype``.
    * ``WrongInputType`` uses INT32.
    * ``ConvOutputShapeMismatch`` grows H and/or W by a multiple of the
      stride.
    * ``WrongOutputType`` draws from ``usableDTypes`` with the valid
      dtype(s) excluded.
    """

    def _out_size(in_size, k_size, pad_before, pad_after, dilation, stride):
        # Dilated convolution output size with floor division
        return (
            in_size - 1 + pad_before + pad_after - (k_size - 1) * dilation
        ) // stride + 1

    h = _out_size(
        ifm.shape[1], filter.shape[0], padding[0], padding[1], dilations[0], strides[0]
    )
    w = _out_size(
        ifm.shape[2], filter.shape[1], padding[2], padding[3], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow by multiples of the stride to avoid the non-integer error case
        if change != 2:  # 1 or 3: grow height
            h += rng.choice(choices) * strides[0]
        if change != 1:  # 2 or 3: grow width
            w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [accum_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004500
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor for a 2D pooling operator (NHWC layout).

    Spatial sizes are ``(in + pad_before + pad_after - kernel) // stride + 1``.
    A non-positive stride or a negative pad yields a 1x1 spatial output.

    ERROR_IF variants:
    * ``PoolingOutputShapeMismatch`` grows H and/or W by a multiple of the
      stride.
    * ``WrongOutputType`` picks a dtype other than ``ifm.dtype``.
    """
    bad_params = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if bad_params:
        # The test is invalid anyway; use a 1x1 spatial output
        out_h, out_w = 1, 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow by multiples of the stride to avoid the non-integer error case
        if change != 2:  # 1 or 3: grow height
            out_h += rng.choice(choices) * stride[0]
        if change != 1:  # 2 or 3: grow width
            out_w += rng.choice(choices) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    out_dtype = ifm.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([ifm.dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004538
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Create the output tensor for FULLY_CONNECTED.

    Shapes: ``input`` is (N, IC), ``filter`` is (OC, IC) and the output is
    (N, OC).

    The output dtype is always ``accum_dtype``. It is validated (and
    deliberately invalidated for ERROR_IF) in arg_gen, so ``rng`` and
    ``error_name`` are not used here.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004551
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the output tensor for MATMUL.

    Shapes: ``a`` is (N, H, C), ``b`` is (N, C, W) and the output is
    (N, H, W).

    Output dtype:
    * normally ``accum_dtype`` (validated in arg_gen);
    * ``WrongInputType``: INT32, as some potentially correct dtype;
    * ``WrongOutputType``: a random dtype from a table of types that are
      invalid for ``a.dtype``. Input dtypes not in that table (INT8, INT16,
      FP32, FP16, BF16) fall back to any dtype other than ``accum_dtype``,
      instead of raising ``UnboundLocalError``.
    """
    # a: N, H, C
    # b: N, C, W
    # out: N, H, W
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Input dtype not covered above: treat anything but the expected
            # accumulator dtype as a wrong output type. Without this branch
            # incorrect_types would be unbound at the rng.choice below.
            incorrect_types = tuple(
                dtype
                for dtype in (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                    DType.BF16,
                )
                if dtype != accum_dtype
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004599
@staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Create the output tensor for CONCAT of the tensors ``a`` along ``axis``.

    The output starts from the first input's shape. The ``axis`` sizes of
    the other inputs are added, unless the ERROR_IF case makes that
    impossible (rank mismatch or an out-of-range axis).

    ERROR_IF variants:
    * ``ConcatShapeSumMismatch`` adds a random 5..9 to the concatenated axis.
    * ``WrongOutputType`` picks a dtype other than the first input's.
    """
    first = a[0]
    others = a[1:]

    out_shape = first.shape.copy()
    can_sum_axis = error_name != ErrorIf.ConcatInputRankMismatch and (
        error_name not in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
    )
    if can_sum_axis:
        for other in others:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    out_dtype = first.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - set([first.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004635
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor for PAD.

    Output dimension ``i`` is ``padding[i][0] + padding[i][1] + a.shape[i]``.

    ERROR_IF variants:
    * ``PadOutputShapeMismatch`` shrinks one random dimension by 1 or 2.
    * ``RankMismatch`` reshapes via ``get_rank_mismatch_shape``.
    * ``WrongOutputType`` picks a dtype other than ``a.dtype``.
    """
    out_shape = a.shape.copy()
    for idx, dim in enumerate(out_shape):
        out_shape[idx] = padding[idx][0] + padding[idx][1] + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(out_shape)))
        out_shape[bad_dim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = get_rank_mismatch_shape(rng, out_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004666
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor for RESHAPE to ``shape``.

    ERROR_IF variants:
    * ``TensorSizeInputOutputMismatch`` grows every dimension by a random
      amount in [1, 10).
    * ``WrongOutputType`` picks a dtype other than ``a.dtype``.
    """
    out_shape = shape.copy()
    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput(out_shape, rng.choice(list(set(candidates) - set([a.dtype]))))
Eric Kunzee5e26762020-10-13 16:11:07 -07004691
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Create the output tensor for SLICE.

    The output shape is ``size`` (``start`` is not used here). The dtype is
    the input's.

    ERROR_IF variants:
    * ``WrongOutputType`` picks another dtype.
    * ``SizeOutputShapeMismatch`` perturbs every dimension.
    * ``InputSizeStartLengthMismatch`` uses the input shape.
    * ``RankMismatch`` reshapes via ``get_rank_mismatch_shape``.
    """
    out_dtype = input.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([input.dtype])))

    if error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = get_rank_mismatch_shape(rng, size.copy())
    else:
        out_shape = size.copy()
        if error_name == ErrorIf.SizeOutputShapeMismatch:
            for idx, dim in enumerate(out_shape):
                # Dimensions <= 2 only grow, so they cannot drop to <= 0
                deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
                out_shape[idx] = dim + rng.choice(deltas)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004725
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor for TILE.

    Output dimension ``i`` is ``a.shape[i] * multiples[i]``.

    ERROR_IF variants:
    * ``RankMismatch`` reshapes via ``get_rank_mismatch_shape``.
    * ``WrongOutputType`` picks a dtype other than ``a.dtype``.
    """
    out_shape = a.shape.copy()
    assert len(multiples) == len(out_shape)

    for idx, dim in enumerate(a.shape):
        out_shape[idx] = dim * multiples[idx]

    if error_name == ErrorIf.RankMismatch:
        out_shape = get_rank_mismatch_shape(rng, out_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004754
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor for TRANSPOSE by ``perms``.

    Output dimension ``i`` is ``a.shape[perms[i]]``. For invalid
    permutations (``IndexOutsideBounds``, ``IndexUsedTwice``), the input
    shape is kept.

    Other ERROR_IF variants:
    * ``TensorSizeInputOutputMismatch`` grows every dimension by a random
      amount in [1, 10).
    * ``RankMismatch`` reshapes via ``get_rank_mismatch_shape``.
    * ``WrongOutputType`` picks a dtype other than ``a.dtype``.
    """
    out_shape = a.shape.copy()
    assert len(perms) == len(out_shape)

    permutation_is_valid = error_name not in (
        ErrorIf.IndexOutsideBounds,
        ErrorIf.IndexUsedTwice,
    )
    if permutation_is_valid:
        for idx in range(len(out_shape)):
            out_shape[idx] = a.shape[perms[idx]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        out_shape = get_rank_mismatch_shape(rng, out_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004787
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor for GATHER.

    ``values`` is expected to be rank 3 (checked unless the case is
    ``WrongRank``) and ``indices`` rank 2. The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``. The dtype is
    that of ``values``, or a different one for ``WrongOutputType``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    batch = values.shape[0]
    n_indices = indices.shape[1]
    channels = values.shape[2]
    out_shape = [batch, n_indices, channels]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, values.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput(
        out_shape, rng.choice(list(set(candidates) - set([values.dtype])))
    )
Kevin Cheng77d0f762020-11-24 10:26:32 -08004813
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor for SCATTER.

    The output has the shape of ``values_in``. The asserts tie N, W and C
    together across ``values_in``, ``indices`` and ``input``; only the rank
    check on ``values_in`` is skipped for ``WrongRank``. The dtype is that of
    ``values_in``, or a different one for ``WrongOutputType``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    out_dtype = values_in.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([values_in.dtype])))

    return ser.addOutput(values_in.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004842
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor for TABLE.

    The output keeps the input's shape. Its dtype is INT32 for an INT16
    input and INT8 otherwise. ``WrongOutputType`` replaces it with any other
    dtype from the candidate list. The input-dtype check is skipped for
    ``WrongInputType``.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype == DType.INT16 or input.dtype == DType.INT8

    expected = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(input.shape, expected)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    candidates.remove(expected)
    return ser.addOutput(input.shape, rng.choice(candidates))
Eric Kunzee5e26762020-10-13 16:11:07 -07004862
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor for a RESIZE operation.

    Input dims 1 and 2 are treated as height and width. The output height
    and width are computed from them together with ``scale`` (read as
    ``[y_n, y_d, x_n, x_d]``), ``offset`` (``[y, x]``) and ``border``
    (``[y, x]``). Dims 0 and 3 are copied from the input unless an ERROR_IF
    case deliberately corrupts them. ``mode`` and ``input_dtype`` are not
    used by the shape calculation.

    Args:
        serializer: serializer used to register the output tensor.
        rng: random generator used by the ERROR_IF branches.
        input: input tensor; only its ``shape`` is read.
        mode, input_dtype: unused here.
        scale, offset, border: resize parameters described above.
        output_dtype: dtype given to the output tensor.
        error_name: optional ERROR_IF case being generated.

    Returns:
        The value returned by ``serializer.addOutput``.
    """
    scale_y_n = scale[0]
    scale_y_d = scale[1]
    scale_x_n = scale[2]
    scale_x_d = scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp to >= 1 so the shape formula below cannot divide by zero.
        scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
            max(v, 1) for v in (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        )

    oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
    ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

    if error_name is not None:
        # ERROR_IF tests may have altered scale/offset/border; keep the
        # output tensor itself constructible.
        oh = max(oh, 1)
        ow = max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = MAX_RESIZE_DIMENSION - 1
            oh = min(oh, limit)
            ow = min(ow, limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        change = rng.choice([1, 2, 3])
        # Step by the scale denominator so we don't hit the non-integer
        # error case instead of the shape mismatch.
        if change in (1, 3):
            if oh + scale_y_d < MAX_RESIZE_DIMENSION:
                oh += scale_y_d
            else:
                oh -= scale_y_d
                assert oh > 0  # Should have been caught in agResize
        if change in (2, 3):
            if ow + scale_x_d < MAX_RESIZE_DIMENSION:
                ow += scale_x_d
            else:
                ow -= scale_x_d
                assert ow > 0  # Should have been caught in agResize

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # Reuses dim 0 for the last entry; presumably because the input may
        # not have a dim 3 in this case - TODO confirm.
        output_dims = [batch, oh, ow, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004941
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create the output tensor for a type conversion.

    The output keeps the shape of ``val`` and takes ``out_dtype``.
    ``rng`` and ``error_name`` are accepted but not used here (presumably
    to match the signature of the other output shapers).

    Returns:
        The value returned by ``ser.addOutput``.
    """
    shape = val.shape
    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004945
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Create the output tensor for a TRANSPOSE_CONV2D operation.

    Args:
        ser: serializer used to register the output tensor.
        rng: random generator used by the ERROR_IF branches.
        ifm: input tensor; only its ``dtype`` is read (WrongOutputType case).
        output_shape: requested output shape. For
            ``ErrorIf.ConvOutputShapeMismatch`` dims 1 and/or 2 are grown
            IN PLACE, so the caller's list also sees the new values.
        accum_dtype: output dtype used for valid tests.
        error_name: optional ERROR_IF case being generated.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        deltas = [1, 2, 3]
        which = rng.choice(deltas)
        # which == 1 -> dim 1, which == 2 -> dim 2, which == 3 -> both.
        for dim, selectors in ((1, (1, 3)), (2, (2, 3))):
            if which in selectors:
                output_shape[dim] += rng.choice(deltas)

    # With a bad input type, just pick some potentially correct output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32, presumably because either
        # is an acceptable output for them.
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00004971
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Create the two output tensors for an FFT2D operation.

    Both outputs share one shape and dtype, copied from ``ifm1``, unless an
    ERROR_IF case deliberately alters them. (The two outputs are presumably
    the real and imaginary parts; not visible from here.)

    Args:
        serializer: serializer used to register the output tensors.
        rng: random generator used by the ERROR_IF branches.
        ifm1, ifm2: input tensors; they must share a dtype, and a shape
            unless ``ErrorIf.FFTInputShapeMismatch`` is being generated.
        error_name: optional ERROR_IF case being generated.

    Returns:
        A list of the two values returned by ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    input_dtype = ifm1.dtype

    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    input_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(input_shape) == 3

    # Build a fresh list: works whether the shape is a list or a tuple
    # (``tuple`` has no ``.copy()``), as rfft2dOp already does, and lets the
    # ERROR_IF branches below edit it without touching ifm1.shape.
    output_shape = list(input_shape)
    output_dtype = input_dtype

    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = list(usableDTypes(excludes=[DType.FP32]))
        output_dtype = rng.choice(wrong_dtypes)
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        modify_dim = rng.choice([1, 2])
        output_shape[modify_dim] += rng.integers(1, 10)

    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
5002
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Create the two output tensors for an RFFT2D operation.

    The output shape equals the input shape with the last dimension
    replaced by ``W // 2 + 1``. Both outputs share that shape and the input
    dtype unless an ERROR_IF case deliberately alters them.

    Args:
        serializer: serializer used to register the output tensors.
        rng: random generator used by the ERROR_IF branches.
        value: input tensor; its ``shape`` and ``dtype`` are read.
        error_name: optional ERROR_IF case being generated.

    Returns:
        A list of the two values returned by ``serializer.addOutput``.
    """
    shape_in = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(shape_in) == 3

    last_dim = shape_in[-1] // 2 + 1
    out_shape = list(shape_in[:-1]) + [last_dim]

    out_dtype = value.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = list(usableDTypes(excludes=[DType.FP32]))
        out_dtype = rng.choice(candidates)
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        dim = rng.choice([1, 2])
        out_shape[dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]