blob: 5f9e2c10b5417ca8831a302d172a3877964043a5 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +010016from generator.tosa_utils import DTYPE_ATTRIBUTES
Jeremy Johnson05c711e2022-12-12 18:00:41 +000017from generator.tosa_utils import get_wrong_output_type
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010018from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010019from generator.tosa_utils import usableDTypes
James Ward24dbc422022-10-19 12:20:31 +010020from generator.tosa_utils import vect_f32_to_bf16
Les Bell0e027d42021-11-09 14:42:14 +000021from tosa.DType import DType
22from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010023
24
Eric Kunzee5e26762020-10-13 16:11:07 -070025class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010026 # Maximum rank of tensor supported by test generator.
27 TOSA_TENSOR_MAX_RANK = 6
28
def __init__(self, args):
    """Set up the test generator from parsed command-line arguments.

    Args:
        args: parsed arguments. This constructor reads ``output_dir``,
            ``random_seed`` and ``tensor_fp_value_range``; other methods read
            further fields such as ``tensor_shape_range`` and ``dump_consts``.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    # No serializer until createSerializer() is called for a specific test.
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When not None (see setTargetShape), makeShape returns this fixed shape
    # instead of drawing a random one.
    self.targetted_shape = None
    # Bounds used for random floating point values, whatever order the two
    # ends of the configured range were given in.
    fp_range = args.tensor_fp_value_range
    self.random_fp_low = min(fp_range)
    self.random_fp_high = max(fp_range)
Eric Kunzee5e26762020-10-13 16:11:07 -070043
def createSerializer(self, opName, testPath):
    """Create a fresh serializer for one test.

    Sets ``self.testPath`` to ``<opName>/<testPath>``, creates the directory
    ``<basePath>/<opName>/<testPath>`` if it does not exist, and stores a new
    serializer targeting that directory in ``self.ser``.
    """
    self.testPath = os.path.join(opName, testPath)
    out_dir = os.path.join(self.basePath, self.testPath)
    os.makedirs(out_dir, exist_ok=True)
    dump_consts = self.args.dump_consts
    self.ser = ts.TosaSerializer(out_dir, saveConstsToFile=dump_consts)
Eric Kunzee5e26762020-10-13 16:11:07 -070050
def getSerializer(self):
    """Return the serializer stored in ``self.ser``.

    This is None until createSerializer() has been called.
    """
    serializer = self.ser
    return serializer
53
def serialize(self, testName):
    """Write the current test's output files into its test directory.

    Two files are written under ``<basePath>/<testPath>``:

    * ``<testName>.tosa``: the bytes returned by ``self.ser.serialize()``.
    * ``desc.json``: the text returned by ``self.ser.writeJson()``, which is
      given the ``.tosa`` file name.
    """
    test_dir = os.path.join(self.basePath, self.testPath)
    tosa_name = "{}.tosa".format(testName)

    with open(os.path.join(test_dir, tosa_name), "wb") as fd:
        fd.write(self.ser.serialize())

    # JSON text must be UTF-8 (RFC 8259); be explicit instead of relying on
    # the platform's locale-dependent default text encoding.
    with open(os.path.join(test_dir, "desc.json"), "w", encoding="utf-8") as fd:
        fd.write(self.ser.writeJson(tosa_name))
Eric Kunzee5e26762020-10-13 16:11:07 -070062
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a newly seeded generator.

    Args:
        seed: seed for the new generator. When None, ``self.random_seed + 1``
            is used, which differs from the seed used in __init__.
    """
    chosen_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(chosen_seed)
67
def getRandTensor(self, shape, dtype):
    """Return a random numpy array of ``shape`` holding values for ``dtype``.

    * BOOL: random False/True values.
    * Integer types: values from the type's range, held as int32 (INT48 is
      held as int64). INT4 is limited to [-7, 7].
    * FP16 / FP32: uniform values in [random_fp_low, random_fp_high) as
      float16 / float32.
    * BF16: float32 draws from the same range passed through
      vect_f32_to_bf16, returned as a float32 array.

    Raises:
        Exception: if ``dtype`` is not recognized.
    """
    rng = self.rng

    if dtype == DType.BOOL:
        return np.bool_(rng.choice(a=[False, True], size=shape))

    # (dtype, inclusive low, exclusive high, numpy container)
    integer_specs = (
        # TOSA specific INT4 weight range from -7 to 7
        (DType.INT4, -7, 8, np.int32),
        (DType.INT8, -128, 128, np.int32),
        (DType.UINT8, 0, 256, np.int32),
        (DType.INT16, -32768, 32768, np.int32),
        (DType.UINT16, 0, 65536, np.int32),
        (DType.INT32, -(1 << 31), 1 << 31, np.int32),
        (DType.INT48, -(1 << 47), 1 << 47, np.int64),
    )
    for int_dtype, low, high, container in integer_specs:
        if dtype == int_dtype:
            return container(rng.integers(low=low, high=high, size=shape))

    if dtype in (DType.FP16, DType.BF16, DType.FP32):
        samples = rng.uniform(
            low=self.random_fp_low, high=self.random_fp_high, size=shape
        )
        if dtype == DType.FP16:
            return np.float16(samples)
        if dtype == DType.BF16:
            # Floor the last 16 bits of each f32 value (via vect_f32_to_bf16)
            return np.float32(vect_f32_to_bf16(np.float32(samples)))
        return np.float32(samples)

    raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700112
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one randomly filled placeholder per (shape, dtype) pair.

    ``shape_list`` and ``dtype_list`` must be the same length.

    Returns:
        A list of the values returned by ``self.ser.addPlaceholder``, in the
        same order as ``shape_list``.
    """
    assert len(shape_list) == len(dtype_list)

    return [
        self.ser.addPlaceholder(shape, dtype, self.getRandTensor(shape, dtype))
        for shape, dtype in zip(shape_list, dtype_list)
    ]
123
def buildConstTensors(self, shape_list, dtype_list):
    """Add one randomly filled constant per (shape, dtype) pair.

    ``shape_list`` and ``dtype_list`` must be the same length.

    Returns:
        A list of the values returned by ``self.ser.addConst``, in the same
        order as ``shape_list``.
    """
    assert len(shape_list) == len(dtype_list)

    consts = []
    for shape, dtype in zip(shape_list, dtype_list):
        values = self.getRandTensor(shape, dtype)
        consts.append(self.ser.addConst(shape, dtype, values))
    return consts
134
def makeShape(self, rank):
    """Return a tensor shape as an int32 numpy array.

    If a non-empty target shape has been set with setTargetShape(), that
    shape is returned and ``rank`` is ignored. Otherwise ``rank`` dimensions
    are drawn from ``[args.tensor_shape_range[0], args.tensor_shape_range[1])``.
    """
    target = self.targetted_shape
    # Avoid plain truthiness: a multi-element numpy array has no truth value
    # and would raise ValueError. None or an empty target still means "no
    # target", as before.
    if target is not None and len(target) > 0:
        return np.int32(target)
    low = self.args.tensor_shape_range[0]
    high = self.args.tensor_shape_range[1]
    return np.int32(self.rng.integers(low=low, high=high, size=rank))
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
def setTargetShape(self, shape):
    """Make subsequent makeShape() calls return ``shape``.

    Pass None to go back to random shapes.
    """
    self.targetted_shape = shape
148
def randInt(self, low=0, high=256):
    """Return a single random np.int32 drawn from [low, high)."""
    draws = self.rng.integers(low=low, high=high, size=1)
    return np.int32(draws)[0]
151
def getRandNumberDType(self, dtype):
    """Return one random value of the given TOSA dtype.

    * FP32 / FP16: uniform in [random_fp_low, random_fp_high), returned as
      np.float32 / np.float16.
    * BF16: a float32 draw from the same range passed through
      vect_f32_to_bf16.
    * BOOL: a random choice of False/True.
    * Integer types: an np.int32 from the type's range (INT4 limited to
      [-7, 7]). INT48 returns an np.int64.

    Raises:
        Exception: if ``dtype`` is not supported.
    """
    if dtype == DType.FP32:
        return np.float32(
            self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
        )
    elif dtype == DType.FP16:
        return np.float16(
            self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
        )
    elif dtype == DType.BF16:
        rand_f32 = np.float32(
            self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
        )
        return vect_f32_to_bf16(rand_f32)
    elif dtype == DType.BOOL:
        return self.rng.choice([False, True])
    # TOSA specific INT4 weight range from -7 to 7
    elif dtype == DType.INT4:
        low, high = (-7, 8)
    elif dtype == DType.INT8:
        low, high = (-128, 128)
    elif dtype == DType.UINT8:
        # Same range getRandTensor uses, so both helpers accept UINT8.
        low, high = (0, 256)
    elif dtype == DType.INT16:
        low, high = (-32768, 32768)
    elif dtype == DType.UINT16:
        # Same range getRandTensor uses, so both helpers accept UINT16.
        low, high = (0, 65536)
    elif dtype == DType.INT32:
        low, high = (-(1 << 31), (1 << 31))
    elif dtype == DType.INT48:
        low, high = (-(1 << 47), (1 << 47))
        # Special size: INT48 does not fit in int32
        return np.int64(self.rng.integers(low, high, size=1))[0]
    else:
        raise Exception("Unknown dtype: {}".format(dtype))

    return np.int32(self.rng.integers(low, high, size=1))[0]
185
def shapeStr(self, shape):
    """Format a shape as its dimensions joined by "x", e.g. [1, 2, 3] -> "1x2x3"."""
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700194
def typeStr(self, dtype):
    """Return the short string name of a dtype, or of a list/tuple of dtypes.

    A list or tuple must hold at least two dtypes. Every entry is converted,
    but only the first two names are joined with "x"; the third entry is the
    accumulator type and is left out of the name.

    Raises:
        Exception: if a dtype has no entry in DTYPE_ATTRIBUTES.
    """
    if not isinstance(dtype, (list, tuple)):
        if dtype not in DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return DTYPE_ATTRIBUTES[dtype]["str"]

    assert len(dtype) >= 2
    names = [self.typeStr(t) for t in dtype]
    # Limit types to the first 2 as the 3rd is the accumulator
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700208
def typeWidth(self, dtype):
    """Get the datatype width for data types.

    Returns the "width" entry of DTYPE_ATTRIBUTES for ``dtype``.

    Raises:
        Exception: if ``dtype`` has no entry in DTYPE_ATTRIBUTES.
    """
    if dtype not in DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700215
# Operator build functions.
# The argument generators (see TosaArgGen) return a list of tuples
# (stringDescriptor, [build_fcn_arg_list]). The string descriptor is used to
# generate the test name, and build_fcn_arg_list is expanded and passed to
# the operator's build function below.
221
def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
    """Add a unary operator with input ``a`` to the serializer.

    ``op`` is normally an op-list entry (a dict with "op" and "operands"
    keys). A plain int op, and IDENTITY, are added directly with no ERROR_IF
    validation and no attribute.

    Returns:
        The result tensor, or None when evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # build_placeholder returns an int, ABS/other ops do not
    if isinstance(op, int):
        self.ser.addOperator(op, a.name, result_tens.name, None)
        return result_tens
    if op["op"] == Op.IDENTITY:
        self.ser.addOperator(op["op"], a.name, result_tens.name, None)
        return result_tens

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongOutputType and result_tens.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, a.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    if op["op"] == Op.NEGATE:
        # NOTE(review): qinfo is indexed unconditionally here; a None qinfo
        # for NEGATE would raise TypeError.
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])
    else:
        attr = None

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
272
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
    """Add a broadcasting binary operator on tensors ``a`` and ``b``.

    Returns:
        The result tensor, or None when evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
305
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a non-broadcasting binary operator on tensors ``a`` and ``b``.

    No ERROR_IF validation is done: ``validator_fcns`` and ``error_name`` are
    accepted but unused here.

    Returns:
        The result tensor.
    """
    out_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op["op"], operand_names, [out_tens.name])
    return out_tens
310
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add an arithmetic-right-shift operator with inputs ``a`` and ``b``.

    ``round`` is passed to ArithmeticRightShiftAttribute.

    Returns:
        The result tensor, or None when evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    shift_attr = ts.TosaSerializerAttribute()
    shift_attr.ArithmeticRightShiftAttribute(round)

    self.ser.addOperator(op["op"], input_list, output_list, shift_attr)
    return result_tens
348
def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
    """Add a multiply operator on tensors ``a`` and ``b``.

    For non floating point inputs the result dtype is forced to INT32. For a
    WrongOutputType ERROR_IF test the result dtype is then replaced by a
    random choice of INT8/INT16/INT48. ``shift`` is passed to MulAttribute.

    Returns:
        The result tensor, or None when evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Special for multiply: force the result to INT32 for INT types
    if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
        result_tens.setDtype(DType.INT32)
    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = [DType.INT8, DType.INT16, DType.INT48]
        result_tens.setDtype(self.rng.choice(wrong_dtypes))

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    mul_attr = ts.TosaSerializerAttribute()
    mul_attr.MulAttribute(shift)

    self.ser.addOperator(op["op"], input_list, output_list, mul_attr)
    return result_tens
393
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a table-lookup operator on ``a``.

    ``table`` is passed to TableAttribute.

    Returns:
        The result tensor, or None when evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    table_attr = ts.TosaSerializerAttribute()
    table_attr.TableAttribute(table)

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list, table_attr)

    return result_tens
427
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Add a select operator with condition ``cond`` and inputs ``a``, ``b``.

    Returns:
        The result tensor, or None when evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, a.name, b.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
464
def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a binary comparison operator on tensors ``a`` and ``b``.

    Returns:
        The result tensor, or None when evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, a, b, error_name
    )

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
503
def build_argmax(self, op, a, axis, validator_fcns, error_name):
    """Add an argmax operator on ``a`` along ``axis``.

    ``axis`` is passed to AxisAttribute.

    Returns:
        The result tensor, or None when evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)
    return result_tens
538
def build_pool2d(
    self,
    op,
    input,
    accum_dtype,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D pooling operator on ``input``.

    ``kernel``, ``stride``, ``pad``, the two zero points from ``qinfo`` and
    ``accum_dtype`` are passed to PoolAttribute. A None ``qinfo`` means both
    zero points are 0.

    Returns:
        The result tensor, or None when evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.pool2dOp(
        self.ser, self.rng, input, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and input.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, input.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input_shape=input.shape,
        input_dtype=input.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    zero_points = [0, 0] if qinfo is None else qinfo

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, pool_attr)
    return result_tens
600
def build_maxpool2d(
    self,
    op,
    input,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a max-pool operator by delegating to build_pool2d.

    Max pooling has no accumulator, so DType.UNKNOWN is supplied as the
    accumulator dtype. All other arguments are forwarded unchanged.
    """
    return self.build_pool2d(
        op,
        input,
        accum_dtype=DType.UNKNOWN,
        stride=stride,
        pad=pad,
        kernel=kernel,
        validator_fcns=validator_fcns,
        error_name=error_name,
        qinfo=qinfo,
    )
625
def build_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D convolution operator with inputs ``ifm``, ``filter`` and ``bias``.

    ``padding`` must have 4 entries. ``padding``, ``strides``, ``dilations``
    and the two zero points from ``qinfo`` are passed to ConvAttribute;
    ``accum_dtype`` is only passed to OutputShaper.conv2dOp.

    Returns:
        The result tensor, or None when evValidateErrorIfs rejects the test.
    """
    assert len(padding) == 4
    result_tens = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    num_operands = sum(op["operands"])

    # Invalidate Input/Output list for error_if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [result_tens.name]
    )

    # NOTE(review): unlike the other builders in this class, result_tensors is
    # not passed here; confirm the conv validators do not need it.
    checks = dict(
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    # NOTE(review): qinfo is indexed unconditionally; a None qinfo would
    # raise TypeError here (build_pool2d defaults it to [0, 0]).
    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, conv_attr)
    return result_tens
697
def build_conv3d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 3D convolution operator to the serializer.

    The output tensor comes from ``OutputShaper.conv3dOp``. The case is
    checked with ``TosaErrorValidator.evValidateErrorIfs`` before the
    operator is serialized with a ``ConvAttribute`` built from the
    padding/stride/dilation values and the two ``qinfo`` entries.

    Args:
        op: operator descriptor; ``op["op"]`` is serialized and
            ``op["operands"]`` is summed into the operand count.
        ifm: input tensor (``name``, ``dtype`` and ``shape`` are read).
        filter: weight tensor (``name``, ``dtype`` and ``shape`` are read).
        bias: bias tensor (only ``name`` is read).
        accum_dtype: accumulator dtype forwarded to the output shaper.
        strides: stride values.
        padding: padding values; exactly 6 entries are required (asserted).
        dilations: dilation values.
        validator_fcns: validators forwarded to ``evValidateErrorIfs``.
        error_name: the ``ErrorIf`` case being generated, or ``None``.
        qinfo: zero-point pair indexed as ``qinfo[0]``/``qinfo[1]``;
            regenerated for ``ErrorIf.WrongInputType`` when the input dtype
            is neither INT8 nor UINT8.

    Returns:
        The output tensor, or ``None`` when ``evValidateErrorIfs`` returns
        a falsy value.
    """
    assert len(padding) == 6
    out_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Keep qinfo consistent with the (possibly changed) output type: rebuild
    # it from the input and output dtypes for non-INT8/UINT8 WrongInputType.
    wrong_non_int8 = error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    )
    if wrong_non_int8:
        input_zp = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        output_zp = TosaQuantGen.getZeroPoint(self, out_tensor.dtype)
        qinfo = [input_zp, output_zp]

    operand_count = sum(op["operands"])
    # The name lists may be altered here for error-if tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=in_names,
        num_operands=operand_count,
        output_list=out_names,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not is_valid:
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], in_names, out_names, conv_attr)
    return out_tensor
769
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D transpose-convolution operator to the serializer.

    The output tensor comes from ``OutputShaper.transposeConv2DOp``. The
    case is checked with ``TosaErrorValidator.evValidateErrorIfs`` before
    the operator is serialized with a ``TransposeConvAttribute``.

    Args:
        op: operator descriptor; ``op["op"]`` is serialized and
            ``op["operands"]`` is summed into the operand count.
        ifm: input tensor (``name``, ``dtype`` and ``shape`` are read).
        filter: weight tensor (``name``, ``dtype`` and ``shape`` are read).
        bias: bias tensor (only ``name`` is read).
        accum_dtype: accumulator dtype forwarded to the output shaper.
        stride: stride values.
        out_pad: output padding; exactly 4 entries are required (asserted).
        output_shape: requested output shape, passed to the shaper and to
            the serialized attribute.
        validator_fcns: validators forwarded to ``evValidateErrorIfs``.
        error_name: the ``ErrorIf`` case being generated, or ``None``.
        qinfo: zero-point pair indexed as ``qinfo[0]``/``qinfo[1]``;
            regenerated for ``ErrorIf.WrongInputType`` when the input dtype
            is neither INT8 nor UINT8.

    Returns:
        The output tensor, or ``None`` when ``evValidateErrorIfs`` returns
        a falsy value.
    """
    assert len(out_pad) == 4
    out_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Keep qinfo consistent with the (possibly changed) output type: rebuild
    # it from the input and output dtypes for non-INT8/UINT8 WrongInputType.
    wrong_non_int8 = error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    )
    if wrong_non_int8:
        input_zp = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        output_zp = TosaQuantGen.getZeroPoint(self, out_tensor.dtype)
        qinfo = [input_zp, output_zp]

    operand_count = sum(op["operands"])
    # The name lists may be altered here for error-if tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=in_names,
        num_operands=operand_count,
        output_list=out_names,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not is_valid:
        return None

    tconv_attr = ts.TosaSerializerAttribute()
    tconv_attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1]
    )
    self.ser.addOperator(op["op"], in_names, out_names, tconv_attr)
    return out_tensor
832
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a depthwise 2D convolution operator to the serializer.

    The output tensor comes from ``OutputShaper.depthwiseConv2dOp``. The
    case is checked with ``TosaErrorValidator.evValidateErrorIfs`` before
    the operator is serialized with a ``ConvAttribute``.

    Args:
        op: operator descriptor; ``op["op"]`` is serialized and
            ``op["operands"]`` is summed into the operand count.
        ifm: input tensor (``name``, ``dtype`` and ``shape`` are read).
        filter: weight tensor (``name``, ``dtype`` and ``shape`` are read).
        bias: bias tensor (only ``name`` is read).
        accum_dtype: accumulator dtype forwarded to the output shaper.
        strides: stride values.
        padding: padding values.
        dilations: dilation values.
        validator_fcns: validators forwarded to ``evValidateErrorIfs``.
        error_name: the ``ErrorIf`` case being generated, or ``None``.
        qinfo: zero-point pair indexed as ``qinfo[0]``/``qinfo[1]``;
            regenerated for ``ErrorIf.WrongInputType`` when the input dtype
            is neither INT8 nor UINT8.

    Returns:
        The output tensor, or ``None`` when ``evValidateErrorIfs`` returns
        a falsy value.
    """
    out_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Keep qinfo consistent with the (possibly changed) output type: rebuild
    # it from the input and output dtypes for non-INT8/UINT8 WrongInputType.
    wrong_non_int8 = error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    )
    if wrong_non_int8:
        input_zp = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        output_zp = TosaQuantGen.getZeroPoint(self, out_tensor.dtype)
        qinfo = [input_zp, output_zp]

    operand_count = sum(op["operands"])
    # The name lists may be altered here for error-if tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=in_names,
        num_operands=operand_count,
        output_list=out_names,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not is_valid:
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], in_names, out_names, conv_attr)
    return out_tensor
903
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a fully-connected operator to the serializer.

    The output tensor comes from ``OutputShaper.fullyConnectedOp``. The
    case is checked with ``TosaErrorValidator.evValidateErrorIfs`` before
    the operator is serialized with a ``FullyConnectedAttribute``.

    Args:
        op: operator descriptor; ``op["op"]`` is serialized and
            ``op["operands"]`` must unpack into exactly two counts.
        ifm: input tensor (``name``, ``dtype`` and ``shape`` are read).
        filter: weight tensor (``name`` and ``dtype`` are read).
        bias: bias tensor (only ``name`` is read).
        accum_dtype: accumulator dtype, forwarded to the shaper and validator.
        validator_fcns: validators forwarded to ``evValidateErrorIfs``.
        error_name: the ``ErrorIf`` case being generated, or ``None``.
        qinfo: zero-point pair indexed as ``qinfo[0]``/``qinfo[1]``.

    Returns:
        The output tensor, or ``None`` when ``evValidateErrorIfs`` returns
        a falsy value.
    """
    out_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    n_placeholders, n_consts = op["operands"]
    # The name lists may be altered here for error-if tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        accum_dtype=accum_dtype,
    )
    if not is_valid:
        return None

    fc_attr = ts.TosaSerializerAttribute()
    fc_attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], in_names, out_names, fc_attr)
    return out_tensor
952
def build_matmul(
    self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a matrix-multiply operator to the serializer.

    The output tensor comes from ``OutputShaper.matmulOp``. The case is
    checked with ``TosaErrorValidator.evValidateErrorIfs`` before the
    operator is serialized with a ``MatMulAttribute``.

    Args:
        op: operator descriptor; ``op["op"]`` is serialized and
            ``op["operands"]`` must unpack into exactly two counts.
        a: first input tensor (``name``, ``shape`` and ``dtype`` are read).
        b: second input tensor (``name``, ``shape`` and ``dtype`` are read).
        accum_dtype: accumulator dtype, forwarded to the shaper and validator.
        validator_fcns: validators forwarded to ``evValidateErrorIfs``.
        error_name: the ``ErrorIf`` case being generated, or ``None``.
        qinfo: zero-point pair indexed as ``qinfo[0]``/``qinfo[1]``.

    Returns:
        The output tensor, or ``None`` when ``evValidateErrorIfs`` returns
        a falsy value.
    """
    out_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    n_placeholders, n_consts = op["operands"]
    # The name lists may be altered here for error-if tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        accum_dtype=accum_dtype,
    )
    if not is_valid:
        return None

    matmul_attr = ts.TosaSerializerAttribute()
    matmul_attr.MatMulAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], in_names, out_names, matmul_attr)
    return out_tensor
994
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Add a reduction operator along ``axis`` to the serializer.

    The output tensor comes from ``OutputShaper.reduceOp``. The case is
    checked with ``TosaErrorValidator.evValidateErrorIfs`` before the
    operator is serialized with an ``AxisAttribute``.

    Args:
        op: operator descriptor; ``op["op"]`` is serialized and
            ``op["operands"]`` must unpack into exactly two counts.
        a: input tensor (``name``, ``shape`` and ``dtype`` are read).
        axis: axis to reduce; forwarded to the shaper, validator and attribute.
        validator_fcns: validators forwarded to ``evValidateErrorIfs``.
            Defaults to ``None``, like every other builder in this class.
            Positional callers are unaffected.
        error_name: the ``ErrorIf`` case being generated, or ``None``.

    Returns:
        The output tensor, or ``None`` when ``evValidateErrorIfs`` returns
        a falsy value.
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1029
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Add a clamp operator with randomly drawn bounds to the serializer.

    Two random values of ``a.dtype`` are drawn with ``getRandNumberDType``.
    Normally the smaller value becomes ``min_val`` and the larger becomes
    ``max_val``. For ``ErrorIf.MaxSmallerMin`` the draw is repeated until
    the two values differ, and they are then assigned inverted so that
    ``max_val < min_val``.

    Args:
        op: operator descriptor; ``op["op"]`` is serialized and
            ``op["operands"]`` must unpack into exactly two counts.
        a: input tensor (``name``, ``shape`` and ``dtype`` are read).
        validator_fcns: validators forwarded to ``evValidateErrorIfs``.
        error_name: the ``ErrorIf`` case being generated, or ``None``.

    Returns:
        The output tensor, or ``None`` when ``evValidateErrorIfs`` returns
        a falsy value.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    bounds = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
    if error_name == ErrorIf.MaxSmallerMin:
        # Equal values cannot produce max < min, so redraw until they differ.
        while bounds[0] == bounds[1]:
            bounds = [
                self.getRandNumberDType(a.dtype),
                self.getRandNumberDType(a.dtype),
            ]
        min_val, max_val = max(bounds), min(bounds)
    else:
        min_val, max_val = min(bounds), max(bounds)

    n_placeholders, n_consts = op["operands"]
    # The name lists may be altered here for error-if tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    if a.dtype not in (DType.BF16, DType.FP16, DType.FP32):
        # Non-floating-point dtypes: the bounds fill the first two value
        # slots and the last two are zero.
        clamp_attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        # Floating-point dtypes: the bounds fill the last two value slots.
        clamp_attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)
    return out_tensor
1085
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Add a leaky-ReLU operator with a random FP32 alpha to the serializer.

    No error-if validation is done here: ``validator_fcns`` is accepted but
    unused, and ``error_name`` is only forwarded to ``OutputShaper.unaryOp``.

    Args:
        op: operator descriptor; ``op["op"]`` is serialized.
        a: input tensor (only ``name`` is read directly).
        validator_fcns: unused.
        error_name: forwarded to the output shaper.

    Returns:
        The output tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    relu_attr = ts.TosaSerializerAttribute()

    alpha = self.getRandNumberDType(DType.FP32)
    relu_attr.LeakyReluAttribute(alpha)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], relu_attr)
    return out_tensor
1094
# NOTE: PRELU needs an additional type/input; this builder currently wires
# only the single input tensor ``a``.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Add a PReLU operator with a single input to the serializer.

    No attribute is attached and no error-if validation is done:
    ``validator_fcns`` is unused, and ``error_name`` is only forwarded to
    ``OutputShaper.unaryOp``.

    Args:
        op: operator descriptor; ``op["op"]`` is serialized.
        a: input tensor (only ``name`` is read directly).
        validator_fcns: unused.
        error_name: forwarded to the output shaper.

    Returns:
        The output tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1101
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Add a sigmoid operator (``op["op"]``) to the serializer.

    The output tensor comes from ``OutputShaper.unaryOp``. The case is
    checked with ``TosaErrorValidator.evValidateErrorIfs`` before the
    operator is added, without an attribute.

    Args:
        op: operator descriptor; ``op["op"]`` is serialized and
            ``op["operands"]`` must unpack into exactly two counts.
        a: input tensor (``name``, ``shape`` and ``dtype`` are read).
        validator_fcns: validators forwarded to ``evValidateErrorIfs``.
        error_name: the ``ErrorIf`` case being generated, or ``None``.

    Returns:
        The output tensor, or ``None`` when ``evValidateErrorIfs`` returns
        a falsy value.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    n_placeholders, n_consts = op["operands"]
    # The name lists may be altered here for error-if tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1132
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Add a tanh operator (``op["op"]``) to the serializer.

    The output tensor comes from ``OutputShaper.unaryOp``. The case is
    checked with ``TosaErrorValidator.evValidateErrorIfs`` before the
    operator is added, without an attribute.

    Args:
        op: operator descriptor; ``op["op"]`` is serialized and
            ``op["operands"]`` must unpack into exactly two counts.
        a: input tensor (``name``, ``shape`` and ``dtype`` are read).
        validator_fcns: validators forwarded to ``evValidateErrorIfs``.
        error_name: the ``ErrorIf`` case being generated, or ``None``.

    Returns:
        The output tensor, or ``None`` when ``evValidateErrorIfs`` returns
        a falsy value.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    n_placeholders, n_consts = op["operands"]
    # The name lists may be altered here for error-if tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1163
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Add a concat operator over a variable number of tensors.

    The variadic arguments are the input tensors followed by the integer
    concat axis as the last element. The axis travels with the list so that
    a variable-length tensor list can be passed. Its type is asserted unless
    ``error_name`` is ``ErrorIf.WrongInputType``.

    Args:
        op: operator descriptor; ``op["op"]`` is serialized and
            ``op["operands"]`` must unpack into exactly two counts.
        *a: input tensors, then the axis.
        validator_fcns: validators forwarded to ``evValidateErrorIfs``.
        error_name: the ``ErrorIf`` case being generated, or ``None``.

    Returns:
        The output tensor, or ``None`` when ``evValidateErrorIfs`` returns
        a falsy value.
    """
    if error_name != ErrorIf.WrongInputType:
        assert type(a[-1]) is int

    # Split the trailing axis off; the tensors stay a tuple.
    tensors, axis = a[:-1], a[-1]

    out_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, *tensors, error_name=error_name
    )

    n_placeholders, n_consts = op["operands"]
    # The name lists may be altered here for error-if tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [t.name for t in tensors], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=tensors[0].shape,
        output_shape=out_tensor.shape,
        input_dtype=tensors[0].dtype,
        output_dtype=out_tensor.dtype,
        inputs=tensors,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001212
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a pad operator to the serializer.

    The output tensor comes from ``OutputShaper.padOp``. The
    ``PadAttribute`` is built against ``self.ser.builder`` before the
    error-if checks run, and is only attached if they pass.

    Args:
        op: operator descriptor; ``op["op"]`` is serialized and
            ``op["operands"]`` must unpack into exactly two counts.
        a: input tensor (``name``, ``shape`` and ``dtype`` are read).
        padding: padding values; ``.flatten()`` is called on it, so it is
            presumably a numpy array (confirm against callers).
        pad_const_int: integer pad constant passed to the attribute.
        pad_const_float: float pad constant passed to the attribute.
        validator_fcns: validators forwarded to ``evValidateErrorIfs``.
        error_name: the ``ErrorIf`` case being generated, or ``None``.
        qinfo: forwarded to ``evValidateErrorIfs`` only.

    Returns:
        The output tensor, or ``None`` when ``evValidateErrorIfs`` returns
        a falsy value.
    """
    out_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    pad_attr = ts.TosaSerializerAttribute()
    flat_padding = padding.flatten()
    pad_attr.PadAttribute(
        self.ser.builder, flat_padding, pad_const_int, pad_const_float
    )

    n_placeholders, n_consts = op["operands"]
    # The name lists may be altered here for error-if tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001260
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Add a reshape operator targeting ``newShape`` to the serializer.

    The output tensor comes from ``OutputShaper.reshapeOp``. The case is
    checked with ``TosaErrorValidator.evValidateErrorIfs`` before the
    operator is serialized with a ``ReshapeAttribute``.

    Args:
        op: operator descriptor; ``op["op"]`` is serialized and
            ``op["operands"]`` must unpack into exactly two counts.
        a: input tensor (``name``, ``shape`` and ``dtype`` are read).
        newShape: requested shape, passed to the shaper and the attribute.
        validator_fcns: validators forwarded to ``evValidateErrorIfs``.
        error_name: the ``ErrorIf`` case being generated, or ``None``.

    Returns:
        The output tensor, or ``None`` when ``evValidateErrorIfs`` returns
        a falsy value.
    """
    out_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, a, newShape, error_name
    )

    n_placeholders, n_consts = op["operands"]
    # The name lists may be altered here for error-if tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)
    self.ser.addOperator(op["op"], in_names, out_names, reshape_attr)
    return out_tensor
1296
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Add a reverse operator along ``axis`` to the serializer.

    The output tensor comes from ``OutputShaper.unaryOp``. The case is
    checked with ``TosaErrorValidator.evValidateErrorIfs`` before the
    operator is serialized with an ``AxisAttribute``.

    Args:
        op: operator descriptor; ``op["op"]`` is serialized and
            ``op["operands"]`` must unpack into exactly two counts.
        a: input tensor (``name``, ``shape`` and ``dtype`` are read).
        axis: axis forwarded to the validator and the attribute.
        validator_fcns: validators forwarded to ``evValidateErrorIfs``.
        error_name: the ``ErrorIf`` case being generated, or ``None``.

    Returns:
        The output tensor, or ``None`` when ``evValidateErrorIfs`` returns
        a falsy value.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    n_placeholders, n_consts = op["operands"]
    # The name lists may be altered here for error-if tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return out_tensor
1331
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Add a transpose operator with permutation ``perms`` to the serializer.

    The output tensor comes from ``OutputShaper.transposeOp``. The
    ``TransposeAttribute`` is created before the error-if checks, and is
    only attached if they pass.

    Args:
        op: operator descriptor; ``op["op"]`` is serialized and
            ``op["operands"]`` must unpack into exactly two counts.
        a: input tensor (``name``, ``shape`` and ``dtype`` are read).
        perms: permutation, passed to the shaper, validator and attribute.
        validator_fcns: validators forwarded to ``evValidateErrorIfs``.
        error_name: the ``ErrorIf`` case being generated, or ``None``.

    Returns:
        The output tensor, or ``None`` when ``evValidateErrorIfs`` returns
        a falsy value.
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(perms)

    n_placeholders, n_consts = op["operands"]
    # The name lists may be altered here for error-if tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, transpose_attr)
    return out_tensor
1366
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Add a slice operator (``start``/``size``) to the serializer.

    The output tensor comes from ``OutputShaper.sliceOp``. The case is
    checked with ``TosaErrorValidator.evValidateErrorIfs`` before the
    operator is serialized with a ``SliceAttribute``.

    Args:
        op: operator descriptor; ``op["op"]`` is serialized and
            ``op["operands"]`` must unpack into exactly two counts.
        a: input tensor (``name``, ``shape`` and ``dtype`` are read).
        start: slice start, passed to the shaper, validator and attribute.
        size: slice size, passed to the shaper, validator and attribute.
        validator_fcns: validators forwarded to ``evValidateErrorIfs``.
        error_name: the ``ErrorIf`` case being generated, or ``None``.

    Returns:
        The output tensor, or ``None`` when ``evValidateErrorIfs`` returns
        a falsy value.
    """
    out_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, a, start, size, error_name
    )

    n_placeholders, n_consts = op["operands"]
    # The name lists may be altered here for error-if tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)
    return out_tensor
1404
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Serialize a TILE operator replicating tensor ``a``.

    Args:
        op: operator definition dict; ``op["op"]`` is the opcode and the two
            counts in ``op["operands"]`` are summed into ``num_operands``.
        a: input tensor.
        multiples: per-dimension repeat counts, stored in the TileAttribute.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF condition under test, or None.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tens.shape,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)
    self.ser.addOperator(op["op"], inputs, outputs, tile_attr)
    return out_tens
1438
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Serialize a GATHER operator reading from ``values``.

    A constant INT32 index tensor of shape [values.shape[0], W] is created
    here. W comes from ``self.randInt`` over the configured tensor shape
    range, and each index is drawn with ``self.rng.integers(low=0,
    high=values.shape[1])`` so the indices stay within the K dimension of
    ``values``.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    k_dim = values.shape[1]
    shape_lo = self.args.tensor_shape_range[0]
    shape_hi = self.args.tensor_shape_range[1]
    w_dim = self.randInt(shape_lo, shape_hi)
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values.shape[0], w_dim])
    )  # (N, W)
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tens = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out_tens.shape,
        input_dtype=values.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tens
1485
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Serialize a SCATTER operator writing ``input`` into ``values_in``.

    A constant INT32 index tensor of shape [values_in.shape[0], input.shape[1]]
    is created here, with each index drawn by ``self.rng.integers(low=0,
    high=values_in.shape[1])`` so it stays within the K dimension of
    ``values_in``.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    k_dim = values_in.shape[1]
    w_dim = input.shape[1]
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values_in.shape[0], w_dim])
    )  # (N, W)
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tens = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input.name],
        [out_tens.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=out_tens.shape,
        input_dtype=values_in.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tens
1530
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Serialize a RESIZE operator on ``input``.

    ``mode``, ``scale``, ``offset`` and ``border`` are stored in the
    ResizeAttribute and are also handed to the ERROR_IF validators together
    with the input/output dtypes and shapes.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=out_tens.shape,
        offset=offset,
        border=border,
        input_list=inputs,
        output_list=outputs,
        result_tensors=[out_tens],
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], inputs, outputs, resize_attr)
    return out_tens
1592
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Serialize an IDENTITYN operator passing ``val`` and ``val2`` through.

    One result tensor is created per input, but only the first is returned.
    ``validator_fcns`` is accepted for interface compatibility; no ERROR_IF
    validation is performed here.

    Returns:
        The result tensor corresponding to ``val``.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # ``op`` is the operator definition dict, as in the other build_* methods
    # here; the serializer must receive the opcode stored under "op", not the
    # whole dict.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1600
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register tensor ``val`` as a graph output and return it unchanged.

    ``op``, ``validator_fcns`` and ``error_name`` are accepted for interface
    compatibility with the other builders and are not used.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001604
1605 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Serialize a CAST operator converting ``val`` to ``out_dtype``.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=out_tens.shape,
        input_dtype=val.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tens
1638
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Serialize a RESCALE operator converting ``val`` to ``out_dtype``.

    Zero points:
        - INT8 and UINT8 get random zero points.
        - UINT16 gets one chosen from {0, 32768}.
        - The zero-point ERROR_IF tests get a forced non-zero value.
        - All other types use 0.

    Scaling:
        A random scale is generated, one per channel when ``per_channel``,
        otherwise a single one. Each scale is converted into a
        multiplier/shift pair by ``TosaQuantGen.computeMultiplierAndShift``.

    Input clamping:
        For scale32 tests without an ERROR_IF, the placeholder data file of
        ``val`` is clamped to the range required by apply_scale_32. The file
        is rewritten only when clamping changed it.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # One scale per channel (last dimension) or a single shared scale
    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Compute the bounds with a Python int.
        # Shifting by the np.int32 element can be evaluated in 32-bit
        # arithmetic (NumPy 2 / NEP 50 promotion, or a 32-bit default int)
        # and silently overflow for shifts above 32.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 1)
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        data_path = os.path.join(self.basePath, self.testPath, val.placeholderFilename)
        values = np.load(data_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.all(np.array_equal(values, val_adj)):
            # Values changed so overwrite file with new values
            np.save(data_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1793
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor for a COND_IF test.

    The constant is built from the data ``[cond]``.

    Type:
        DType.BOOL, except for the CondIfCondNotMatchingBool ERROR_IF test,
        where ``get_wrong_output_type`` picks a wrong type.

    Shape:
        Rank 0, except for the CondIfCondShapeNotSizeOne ERROR_IF test, where
        it is randomly [2] or [1, 2].
    """
    if error_name == ErrorIf.CondIfCondNotMatchingBool:
        wanted_type = get_wrong_output_type(op, self.rng, DType.BOOL)
    else:
        wanted_type = DType.BOOL

    wanted_shape = []  # rank 0: a single element
    if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
        wanted_shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(wanted_shape, wanted_type, [cond])
1810
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Serialize a COND_IF whose THEN/ELSE blocks each output an INT32 constant.

    Inputs used:
        Only the shape of ``then_tens`` and the condition ``cond`` are used.
        ``else_tens`` is ignored.

    Branch bodies:
        Each branch body is a freshly generated constant.

    ERROR_IF handling:
        For CondIfOutputListThenGraphMismatch and
        CondIfOutputListElseGraphMismatch, the matching branch instead outputs
        a constant with a deliberately different shape.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    then_mismatch = ErrorIf.CondIfOutputListThenGraphMismatch
    else_mismatch = ErrorIf.CondIfOutputListElseGraphMismatch
    if error_name in (then_mismatch, else_mismatch):
        # Perturb every dimension; dims of 3 or less are only ever grown
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            if dim > 3:
                incorrect_shape[idx] = dim + self.rng.choice([-3, -2, 2, 3])
            else:
                incorrect_shape[idx] = dim + self.rng.choice([1, 2, 4])
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The operator's result tensor has the (correct) branch output shape
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block holds a single constant, which is also the block's output
    branches = (
        (then_block, then_mismatch, then_arr),
        (else_block, else_mismatch, else_arr),
    )
    for block_name, mismatch_error, branch_arr in branches:
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            branch_tens = self.ser.addConst(
                incorrect_shape, DType.INT32, incorrect_arr
            )
        else:
            branch_tens = self.ser.addConst(out_shape, DType.INT32, branch_arr)
        self.ser.addOutputTensor(branch_tens)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not is_valid:
        return None

    return result_tens
1879
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Serialize a COND_IF whose THEN/ELSE blocks apply a binary op to ``a`` and ``b``.

    Block operators by dtype:
        - FP32, BF16, FP16 and INT32 use ADD (then) / SUB (else).
        - INT8 and INT16 use LOGICAL_RIGHT_SHIFT (then) / LOGICAL_LEFT_SHIFT
          (else).

    ERROR_IF handling:
        For the CondIf{Input,Output}List{Then,Else}GraphMismatch tests, the
        matching block gets an input or output whose shape differs from
        ``a``.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.

    Raises:
        AssertionError: if ``a.dtype`` has no block operators defined.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            # Like build_cond_if_const: only grow small dims, so the incorrect
            # shape never gets zero or negative dimensions
            incorrect_shape[i] += (
                self.rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else self.rng.choice([1, 2, 4])
            )
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Explicit raise so the guard also holds under ``python -O``
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # The loop variable must not be named ``op``.
    # That would shadow the operator definition dict, which is passed to the
    # ERROR_IF validation below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
1962
Matthew Haddon630c17c2021-10-14 15:05:41 +01001963 def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001964 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07001965
Kevin Cheng550ccc52021-03-03 11:21:43 -08001966 cond_block = "COND_BLOCK"
1967 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001968
1969 attr = ts.TosaSerializerAttribute()
1970 attr.WhileLoopAttribute(cond_block, body_block)
1971
1972 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001973 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001974 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001975 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07001976
1977 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08001978 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
1979 a_out = self.ser.addIntermediate(a.shape, a.dtype)
Matthew Haddon630c17c2021-10-14 15:05:41 +01001980 if error_name == ErrorIf.InputListOutputListMismatch:
1981 incorrect_acc = deepcopy(acc)
1982 for i in range(len(incorrect_acc.shape)):
1983 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
1984 acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
1985 else:
1986 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001987
1988 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08001989 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001990 op["op"],
Kevin Cheng550ccc52021-03-03 11:21:43 -08001991 [iter.name, a.name, acc.name],
1992 [iter_out.name, a_out.name, acc_out.name],
1993 attr,
1994 )
Kevin Chengb227ae52021-09-02 13:43:17 -07001995 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07001996
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001997 if error_name in [
1998 ErrorIf.InputListCondGraphMismatch,
1999 ErrorIf.InputListBodyGraphInputMismatch,
2000 ErrorIf.InputListBodyGraphOutputMismatch,
2001 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01002002 incorrect_iter = deepcopy(iter)
2003 for i in range(len(incorrect_iter.shape)):
2004 incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
2005 if len(incorrect_iter.shape) == 0:
2006 incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
2007
2008 incorrect_acc = deepcopy(acc)
2009 for i in range(len(incorrect_acc.shape)):
2010 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2011
Eric Kunzee5e26762020-10-13 16:11:07 -07002012 # COND block (input: iter, output: cond_tens )
Jerry Ge9e94af82022-10-27 09:57:00 -07002013 self.ser.addBasicBlock(cond_block)
2014
Matthew Haddon630c17c2021-10-14 15:05:41 +01002015 if error_name == ErrorIf.InputListCondGraphMismatch:
2016 self.ser.addInputTensor(incorrect_iter)
2017 self.ser.addInputTensor(a)
2018 self.ser.addInputTensor(incorrect_acc)
2019 else:
2020 self.ser.addInputTensor(iter)
2021 self.ser.addInputTensor(a)
2022 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002023 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002024
2025 if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002026 cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002027 else:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002028 cond_type = DType.BOOL
2029 if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
2030 choice = self.rng.choice([1, 2])
2031 if choice == 1:
2032 cond_shape = [3]
2033 else:
2034 cond_shape = [1, 2]
2035 else:
2036 cond_shape = []
2037 cond_tens = self.ser.addOutput(cond_shape, cond_type)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002038
Kevin Cheng550ccc52021-03-03 11:21:43 -08002039 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002040
2041 # BODY block (input: a, acc, iter, output: a, acc, iter)
2042 # Note that local intermediate tensors need to be declared here for the outputs
Jerry Ge9e94af82022-10-27 09:57:00 -07002043 self.ser.addBasicBlock(body_block)
2044
Matthew Haddon630c17c2021-10-14 15:05:41 +01002045 if error_name == ErrorIf.InputListBodyGraphInputMismatch:
2046 self.ser.addInputTensor(incorrect_iter)
2047 self.ser.addInputTensor(a)
2048 self.ser.addInputTensor(incorrect_acc)
2049 else:
2050 self.ser.addInputTensor(iter)
2051 self.ser.addInputTensor(a)
2052 self.ser.addInputTensor(acc)
2053
Kevin Cheng550ccc52021-03-03 11:21:43 -08002054 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002055
2056 if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002057 iter_body_out = self.ser.addIntermediate(
2058 incorrect_iter.shape, incorrect_iter.dtype
2059 )
2060 acc_body_out = self.ser.addIntermediate(
2061 incorrect_acc.shape, incorrect_acc.dtype
2062 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01002063 else:
2064 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2065 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
2066
Eric Kunzee5e26762020-10-13 16:11:07 -07002067 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2068 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2069 self.ser.addOutputTensor(iter_body_out)
2070 self.ser.addOutputTensor(a)
2071 self.ser.addOutputTensor(acc_body_out)
2072
Les Bell729b0352021-11-24 10:28:21 +00002073 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01002074 self.ser,
2075 validator_fcns,
2076 error_name,
2077 op=op,
Jerry Ge9e94af82022-10-27 09:57:00 -07002078 basicBlocks=self.ser.currRegion.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00002079 ):
2080 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01002081
Eric Kunzee5e26762020-10-13 16:11:07 -07002082 return acc_out
2083
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
    """Add an RFFT2D operator that consumes tensor ``val``.

    The result tensors are produced by ``OutputShaper.rfft2dOp``. Before the
    operator is added, the input/output name lists are passed through
    ``TosaErrorIfArgGen.eiInvalidateInputOutputList`` and the ERROR_IF
    validators are run.

    Returns:
        The list of result tensors, or None when
        ``TosaErrorValidator.evValidateErrorIfs`` reports the test as invalid
        (no operator is added in that case).
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholder_count, const_count = op["operands"]
    out_names = [tensor.name for tensor in results]
    out_dtypes = [tensor.dtype for tensor in results]

    # NOTE(review): presumably alters the name lists only for the
    # input/output-list ERROR_IF cases - confirm in TosaErrorIfArgGen.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_dtype=out_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return results
2115
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Return the shape/rank/dtype filters to generate tests with for ``op``.

    Requested filters are restricted to what the operator allows
    (``op["rank"]`` as an inclusive (min, max) pair and ``op["types"]``).

    Args:
        op: TOSA_OP_LIST entry; its "rank" and "types" fields are read.
        shapeFilter: list of requested shapes; a falsy value means "no shape
            filter" and is treated as ``[None]``.
        rankFilter: iterable of requested ranks, or None.
        dtypeFilter: collection of requested dtypes, or None.
        testType: "positive" or "negative".
        validator: ERROR_IF validator; for negative tests the "param_reqs" it
            returns override the rank/dtype/shape filters when not None.

    Returns:
        dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or None
        for a negative test without a validator or an unrecognised testType.
    """
    # Default testing rank range, 1-4 inclusive, to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    opRankRange = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Keep only the requested ranks the operator allows
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Ensure default behaviour is bounded by default range or by operator,
        # whichever is the smaller range of ranks.
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            # The default range must still respect the operator's limits,
            # e.g. an op with rank (2, 6) must never be tested at rank 1.
            low = max(rmin, default_test_rank_range.start)
            high = min(rmax, default_test_rank_range.stop - 1)
            if low <= high:
                cleanRankFilter = range(low, high + 1)
            else:
                # No overlap at all: fall back to the operator's lowest ranks
                cleanRankFilter = opRankRange[: len(default_test_rank_range)]
    else:
        cleanRankFilter = opRankRange

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Operator dtypes filtered by the requested dtypes; list entries
        # (e.g. [Input Type 1, Input Type 2, Accumulator Type]) are matched
        # on their first element.
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType == "negative":
        if validator is None:
            return None

        error_arguments = validator(check=False, op=op)["param_reqs"]

        # Parameters required by the ERROR_IF condition take precedence
        if error_arguments["rank"] is not None:
            rankFilter = error_arguments["rank"]
        else:
            rankFilter = cleanRankFilter

        if error_arguments["dtype"] is not None:
            dtypeFilter = error_arguments["dtype"]
        else:
            dtypeFilter = cleanDtypeFilter

        if error_arguments["shape"] is not None:
            shapeFilter = error_arguments["shape"]
        else:
            # Reduce number of shapes to keep test numbers small
            shapeFilter = shapeFilter[:2]

        return {
            "shapeFilter": shapeFilter,
            "rankFilter": rankFilter,
            "dtypeFilter": dtypeFilter,
        }

    return None
2196
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Enumerate the tests to generate for operator ``opName``.

    Args:
        opName: key into TOSA_OP_LIST.
        shapeFilter: optional list of shapes to restrict tests to; None (the
            default) means no shape filter, same as ``[None]``.
        rankFilter: optional ranks to restrict tests to.
        dtypeFilter: optional dtypes to restrict tests to.
        testType: "positive" or "negative" (ERROR_IF) tests.

    Returns:
        list of tuples ``(opName, testStr, dtype, error_name, shapeList, args)``.

    Raises:
        Exception: if ``opName`` is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if shapeFilter is None:
        # Avoids a shared mutable default argument
        shapeFilter = [None]

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []

        for r in filterDict["rankFilter"]:
            for t in filterDict["dtypeFilter"]:
                for shape in filterDict["shapeFilter"]:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument list entries are (string representation,
                    # build function argument list)
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "negative":
                            testStr = f"{opName}_ERRORIF_{error_name}_{shapeStr}_{typeStr}"
                        else:
                            testStr = f"{opName}_{shapeStr}_{typeStr}"
                        if argStr:
                            testStr += f"_{argStr}"

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement: a test is dropped if ANY validator flags it.
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2303
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build and serialize a single test case for ``opName``.

    A serializer is created for the test, quantization info (if the op has a
    "qgen") and tensor values are generated, then the op's build function is
    called. A truthy build result is serialized as "test"; otherwise an
    "Invalid ERROR_IF test created" message is printed.

    Raises:
        Exception: if ``opName`` is not in TOSA_OP_LIST.
    """
    if opName not in self.TOSA_OP_LIST:
        raise Exception("Cannot find op with name {}".format(opName))
    op = self.TOSA_OP_LIST[opName]

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _tgen_fcn, tvgen_fcn, _agen_fcn = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholders, consts = op["operands"]
    num_operands = placeholders + consts
    is_concat = op["op"] == Op.CONCAT

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        # CONCAT takes a variable number of inputs: one dtype per input shape
        repeat = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * repeat

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    # qinfo is generated before the tensor values; keep this order
    # (NOTE(review): presumably both draw from self.rng - confirm)
    qinfo = qgen(self, op, dtype_or_dtypeList, error_name) if qgen is not None else None

    # Build the random tensor operands and the test
    tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

    try:
        if error_if_validators is None:
            # Without ERROR_IF validators, qinfo (when present) is positional
            trailing = [] if qinfo is None else [qinfo]
            resultName = build_fcn(self, op, *tens, *testArgs, *trailing)
        else:
            build_kwargs = {
                "validator_fcns": error_if_validators,
                "error_name": error_name,
            }
            if qinfo is not None:
                build_kwargs["qinfo"] = qinfo
            resultName = build_fcn(self, op, *tens, *testArgs, **build_kwargs)
    except TypeError as e:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise e

    if not resultName:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it
    self.serialize("test")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002394
def createDynamicOpLists(self):
    """Expand the templated convolution entries of TOSA_OP_LIST.

    For every kernel size, a shallow copy of the matching ``*_TEMPLATE``
    entry is added under a name such as ``conv2d_3x1``, with "filter" set to
    the kernel and "template" set to False. Afterwards every entry whose
    "template" field is truthy is deleted. Returns immediately when
    "conv2d_TEMPLATE" is absent, i.e. the lists were already created (which
    can occur when the class is initialized more than once).
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        return

    def add_variant(base, kernel):
        # Shallow copy: nested values stay shared with the template
        variant = op_list["{}_TEMPLATE".format(base)].copy()
        variant["filter"] = kernel
        variant["template"] = False
        op_list["{}_{}".format(base, "x".join(str(dim) for dim in kernel))] = variant

    kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
    kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    # Kernel-major order: conv2d, depthwise and transpose variants of one
    # kernel are added together before moving to the next kernel
    for kernel in kernels_2d:
        for base in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            add_variant(base, kernel)
    for kernel in kernels_3d:
        add_variant("conv3d", kernel)

    # Two passes: collect the template names first because keys must not be
    # deleted from a dict while iterating over it
    stale = [name for name, entry in op_list.items() if entry.get("template")]
    for name in stale:
        del op_list[name]
2445
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.

    Also lints each TOSA_OP_LIST entry: it must unpack "operands" into two
    values and "build_fcn" into four, and must contain "types" and "op".
    The first failing check raises an Exception naming the op. Entries
    lacking "rank" get ``self.DEFAULT_RANK_RANGE``.
    """
    for name, entry in self.TOSA_OP_LIST.items():
        # Required fields
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(name)
            )

        try:
            _build, _tgen, _tvgen, _agen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(name)
            )

        if "types" not in entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(name)
            )

        if "op" not in entry:
            raise Exception("Op {} is missing the Op field in TOSA_OP_LIST".format(name))

        # Put in default rank range, if missing
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07002487
# Tensor operator list
# 'op': op name
# 'operands': tuple of (placeholder, const) operands
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#     if not specified, initOpListDefaults fills in DEFAULT_RANK_RANGE,
#     i.e. (1, TOSA_TENSOR_MAX_RANK); test generation without a rank or
#     shape filter still limits itself to ranks 1-4 (see create_filter_lists)
# 'build_fcn': tuple of (build_operator() function, TensorGen function,
#     TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested
# Optional keys used below: 'qgen', 'error_if_validators',
#     'invalid_test_validators', and for templated ops 'template'/'filter'

# Floating-point types
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
# Floating-point, integer (excluding INT4) and boolean types
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
TYPE_FI16 = [DType.FP32, DType.INT16]  # FP32 and INT16

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range given to ops whose TOSA_OP_LIST entry has no 'rank'
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002539
2540 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002541 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002542 "argmax": {
2543 "op": Op.ARGMAX,
2544 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002545 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002546 "build_fcn": (
2547 build_argmax,
2548 TosaTensorGen.tgBasic,
2549 TosaTensorValuesGen.tvgDefault,
2550 TosaArgGen.agAxis,
2551 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002552 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002553 "error_if_validators": (
2554 TosaErrorValidator.evAxisSmallerZero,
2555 TosaErrorValidator.evAxisLargerRank,
2556 TosaErrorValidator.evArgmaxOutputRankMismatch,
2557 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2558 TosaErrorValidator.evWrongRank,
2559 TosaErrorValidator.evWrongInputType,
2560 TosaErrorValidator.evWrongOutputType,
2561 TosaErrorValidator.evWrongInputList,
2562 TosaErrorValidator.evWrongOutputList,
2563 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002564 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002565 "avg_pool2d": {
2566 "op": Op.AVG_POOL2D,
2567 "operands": (1, 0),
2568 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002569 "build_fcn": (
2570 build_pool2d,
2571 TosaTensorGen.tgNHWC,
2572 TosaTensorValuesGen.tvgDefault,
2573 TosaArgGen.agPooling,
2574 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002575 "qgen": TosaQuantGen.qgUnary,
2576 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002577 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002578 "error_if_validators": (
2579 TosaErrorValidator.evKernelSmallerOne,
2580 TosaErrorValidator.evStrideSmallerOne,
2581 TosaErrorValidator.evPadSmallerZero,
2582 TosaErrorValidator.evWrongRank,
2583 TosaErrorValidator.evWrongInputType,
2584 TosaErrorValidator.evWrongOutputType,
2585 TosaErrorValidator.evWrongInputList,
2586 TosaErrorValidator.evWrongOutputList,
2587 TosaErrorValidator.evInputZeroPointNotZero,
2588 TosaErrorValidator.evOutputZeroPointNotZero,
2589 TosaErrorValidator.evPadLargerEqualKernel,
2590 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002591 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002592 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002593 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002594 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002595 "conv2d_TEMPLATE": {
2596 "op": Op.CONV2D,
2597 "operands": (1, 2),
2598 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002599 "build_fcn": (
2600 build_conv2d,
2601 TosaTensorGen.tgConv2D,
2602 TosaTensorValuesGen.tvgDefault,
2603 TosaArgGen.agConv,
2604 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002605 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002606 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002607 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2608 "error_if_validators": (
2609 TosaErrorValidator.evWrongInputType,
2610 TosaErrorValidator.evWrongOutputType,
2611 TosaErrorValidator.evWrongInputList,
2612 TosaErrorValidator.evWrongOutputList,
2613 TosaErrorValidator.evInputZeroPointNotZero,
2614 TosaErrorValidator.evWeightZeroPointNotZero,
2615 TosaErrorValidator.evPadSmallerZero,
2616 TosaErrorValidator.evStrideSmallerOne,
2617 TosaErrorValidator.evDilationSmallerOne,
2618 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002619 TosaErrorValidator.evConvOutputShapeMismatch,
2620 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002621 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002622 "template": True,
2623 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002624 # Templated operator. Filled in by createDynamicOpLists
2625 "conv3d_TEMPLATE": {
2626 "op": Op.CONV3D,
2627 "operands": (1, 2),
2628 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002629 "build_fcn": (
2630 build_conv3d,
2631 TosaTensorGen.tgConv3D,
2632 TosaTensorValuesGen.tvgDefault,
2633 TosaArgGen.agConv,
2634 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002635 "qgen": TosaQuantGen.qgConv,
2636 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002637 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2638 "error_if_validators": (
2639 TosaErrorValidator.evWrongInputType,
2640 TosaErrorValidator.evWrongOutputType,
2641 TosaErrorValidator.evWrongInputList,
2642 TosaErrorValidator.evWrongOutputList,
2643 TosaErrorValidator.evInputZeroPointNotZero,
2644 TosaErrorValidator.evWeightZeroPointNotZero,
2645 TosaErrorValidator.evPadSmallerZero,
2646 TosaErrorValidator.evStrideSmallerOne,
2647 TosaErrorValidator.evDilationSmallerOne,
2648 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002649 TosaErrorValidator.evConvOutputShapeMismatch,
2650 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002651 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002652 "template": True,
2653 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002654 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002655 "depthwise_conv2d_TEMPLATE": {
2656 "op": Op.DEPTHWISE_CONV2D,
2657 "operands": (1, 2),
2658 "filter": [1, 1],
2659 "rank": (4, 4),
2660 "build_fcn": (
2661 build_depthwise_conv2d,
2662 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002663 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002664 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002665 ),
2666 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002667 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002668 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2669 "error_if_validators": (
2670 TosaErrorValidator.evWrongInputType,
2671 TosaErrorValidator.evWrongOutputType,
2672 TosaErrorValidator.evWrongInputList,
2673 TosaErrorValidator.evWrongOutputList,
2674 TosaErrorValidator.evInputZeroPointNotZero,
2675 TosaErrorValidator.evWeightZeroPointNotZero,
2676 TosaErrorValidator.evPadSmallerZero,
2677 TosaErrorValidator.evStrideSmallerOne,
2678 TosaErrorValidator.evDilationSmallerOne,
2679 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002680 TosaErrorValidator.evConvOutputShapeMismatch,
2681 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002682 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002683 "template": True,
2684 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002685 "fully_connected": {
2686 "op": Op.FULLY_CONNECTED,
2687 "operands": (1, 2),
2688 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002689 "build_fcn": (
2690 build_fully_connected,
2691 TosaTensorGen.tgFullyConnected,
2692 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002693 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002694 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002695 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002696 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002697 "error_if_validators": (
2698 TosaErrorValidator.evInputZeroPointNotZero,
2699 TosaErrorValidator.evWeightZeroPointNotZero,
2700 TosaErrorValidator.evWrongRank,
2701 TosaErrorValidator.evWrongInputType,
2702 TosaErrorValidator.evWrongOutputType,
2703 TosaErrorValidator.evWrongInputList,
2704 TosaErrorValidator.evWrongOutputList,
2705 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002706 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002707 "matmul": {
2708 "op": Op.MATMUL,
2709 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002710 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002711 "build_fcn": (
2712 build_matmul,
2713 TosaTensorGen.tgMatmul,
2714 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002715 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002716 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002717 "qgen": TosaQuantGen.qgMatmul,
2718 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002719 "error_if_validators": (
2720 TosaErrorValidator.evInputZeroPointNotZero,
2721 TosaErrorValidator.evWrongRank,
2722 TosaErrorValidator.evWrongInputType,
2723 TosaErrorValidator.evWrongOutputType,
2724 TosaErrorValidator.evWrongInputList,
2725 TosaErrorValidator.evWrongOutputList,
2726 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002727 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002728 "max_pool2d": {
2729 "op": Op.MAX_POOL2D,
2730 "operands": (1, 0),
2731 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002732 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01002733 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002734 TosaTensorGen.tgNHWC,
2735 TosaTensorValuesGen.tvgDefault,
2736 TosaArgGen.agPooling,
2737 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002738 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002739 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002740 "error_if_validators": (
2741 TosaErrorValidator.evKernelSmallerOne,
2742 TosaErrorValidator.evStrideSmallerOne,
2743 TosaErrorValidator.evPadSmallerZero,
2744 TosaErrorValidator.evWrongRank,
2745 TosaErrorValidator.evWrongInputType,
2746 TosaErrorValidator.evWrongOutputType,
2747 TosaErrorValidator.evWrongInputList,
2748 TosaErrorValidator.evWrongOutputList,
2749 TosaErrorValidator.evPadLargerEqualKernel,
2750 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002751 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002752 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002753 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002754 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002755 "transpose_conv2d_TEMPLATE": {
2756 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002757 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002758 "rank": (4, 4),
2759 "build_fcn": (
2760 build_transpose_conv2d,
2761 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002762 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002763 TosaArgGen.agTransposeConv2D,
2764 ),
2765 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002766 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002767 "invalid_test_validators": (
2768 TosaInvalidValidator.ivHeightWidthInvalid,
2769 TosaInvalidValidator.ivNonPositiveOutputShape,
2770 ),
2771 "error_if_validators": (
2772 TosaErrorValidator.evWrongInputType,
2773 TosaErrorValidator.evWrongOutputType,
2774 TosaErrorValidator.evWrongInputList,
2775 TosaErrorValidator.evWrongOutputList,
2776 TosaErrorValidator.evInputZeroPointNotZero,
2777 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07002778 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00002779 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002780 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002781 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002782 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002783 "template": True,
2784 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002785 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002786 "clamp": {
2787 "op": Op.CLAMP,
2788 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002789 "build_fcn": (
2790 build_clamp,
2791 TosaTensorGen.tgBasic,
2792 TosaTensorValuesGen.tvgDefault,
2793 None,
2794 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002795 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002796 "error_if_validators": (
2797 TosaErrorValidator.evMaxSmallerMin,
2798 TosaErrorValidator.evWrongInputType,
2799 TosaErrorValidator.evWrongOutputType,
2800 TosaErrorValidator.evWrongInputList,
2801 TosaErrorValidator.evWrongOutputList,
2802 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002803 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002804 "sigmoid": {
2805 "op": Op.SIGMOID,
2806 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002807 "build_fcn": (
2808 build_sigmoid,
2809 TosaTensorGen.tgBasic,
2810 TosaTensorValuesGen.tvgDefault,
2811 None,
2812 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002813 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002814 "error_if_validators": (
2815 TosaErrorValidator.evWrongInputType,
2816 TosaErrorValidator.evWrongOutputType,
2817 TosaErrorValidator.evWrongInputList,
2818 TosaErrorValidator.evWrongOutputList,
2819 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002820 },
2821 "tanh": {
2822 "op": Op.TANH,
2823 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002824 "build_fcn": (
2825 build_tanh,
2826 TosaTensorGen.tgBasic,
2827 TosaTensorValuesGen.tvgDefault,
2828 None,
2829 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002830 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002831 "error_if_validators": (
2832 TosaErrorValidator.evWrongInputType,
2833 TosaErrorValidator.evWrongOutputType,
2834 TosaErrorValidator.evWrongInputList,
2835 TosaErrorValidator.evWrongOutputList,
2836 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002837 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002838 # Elementwise Binary Operators
2839 "add": {
2840 "op": Op.ADD,
2841 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002842 "build_fcn": (
2843 build_binary_broadcast,
2844 TosaTensorGen.tgBroadcastFuzz,
2845 TosaTensorValuesGen.tvgAddSub,
2846 None,
2847 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002848 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002849 "error_if_validators": (
2850 TosaErrorValidator.evRankMismatch,
2851 TosaErrorValidator.evWrongInputType,
2852 TosaErrorValidator.evWrongOutputType,
2853 TosaErrorValidator.evWrongInputList,
2854 TosaErrorValidator.evWrongOutputList,
2855 TosaErrorValidator.evDimensionMismatch,
2856 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002857 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002858 "arithmetic_right_shift": {
2859 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2860 "operands": (2, 0),
2861 "build_fcn": (
2862 build_arithmetic_right_shift,
2863 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002864 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002865 TosaArgGen.agArithmeticRightShift,
2866 ),
2867 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002868 "error_if_validators": (
2869 TosaErrorValidator.evRankMismatch,
2870 TosaErrorValidator.evWrongInputType,
2871 TosaErrorValidator.evWrongOutputType,
2872 TosaErrorValidator.evWrongInputList,
2873 TosaErrorValidator.evWrongOutputList,
2874 TosaErrorValidator.evDimensionMismatch,
2875 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002876 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002877 "bitwise_and": {
2878 "op": Op.BITWISE_AND,
2879 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002880 "build_fcn": (
2881 build_binary_broadcast,
2882 TosaTensorGen.tgBroadcastFuzz,
2883 TosaTensorValuesGen.tvgDefault,
2884 None,
2885 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002886 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002887 "error_if_validators": (
2888 TosaErrorValidator.evRankMismatch,
2889 TosaErrorValidator.evWrongInputType,
2890 TosaErrorValidator.evWrongOutputType,
2891 TosaErrorValidator.evWrongInputList,
2892 TosaErrorValidator.evWrongOutputList,
2893 TosaErrorValidator.evDimensionMismatch,
2894 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002895 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002896 "bitwise_or": {
2897 "op": Op.BITWISE_OR,
2898 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002899 "build_fcn": (
2900 build_binary_broadcast,
2901 TosaTensorGen.tgBroadcastFuzz,
2902 TosaTensorValuesGen.tvgDefault,
2903 None,
2904 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002905 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002906 "error_if_validators": (
2907 TosaErrorValidator.evRankMismatch,
2908 TosaErrorValidator.evWrongInputType,
2909 TosaErrorValidator.evWrongOutputType,
2910 TosaErrorValidator.evWrongInputList,
2911 TosaErrorValidator.evWrongOutputList,
2912 TosaErrorValidator.evDimensionMismatch,
2913 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002914 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002915 "bitwise_xor": {
2916 "op": Op.BITWISE_XOR,
2917 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002918 "build_fcn": (
2919 build_binary_broadcast,
2920 TosaTensorGen.tgBroadcastFuzz,
2921 TosaTensorValuesGen.tvgDefault,
2922 None,
2923 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002924 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002925 "error_if_validators": (
2926 TosaErrorValidator.evRankMismatch,
2927 TosaErrorValidator.evWrongInputType,
2928 TosaErrorValidator.evWrongOutputType,
2929 TosaErrorValidator.evWrongInputList,
2930 TosaErrorValidator.evWrongOutputList,
2931 TosaErrorValidator.evDimensionMismatch,
2932 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002933 },
Matthew Haddon459443c2021-08-23 16:43:13 +01002934 "intdiv": {
2935 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002936 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002937 "build_fcn": (
2938 build_binary_broadcast,
2939 TosaTensorGen.tgBroadcastFuzz,
2940 TosaTensorValuesGen.tvgIntDiv,
2941 None,
2942 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002943 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002944 "error_if_validators": (
2945 TosaErrorValidator.evRankMismatch,
2946 TosaErrorValidator.evWrongInputType,
2947 TosaErrorValidator.evWrongOutputType,
2948 TosaErrorValidator.evWrongInputList,
2949 TosaErrorValidator.evWrongOutputList,
2950 TosaErrorValidator.evDimensionMismatch,
2951 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002952 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002953 "logical_and": {
2954 "op": Op.LOGICAL_AND,
2955 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002956 "build_fcn": (
2957 build_binary_broadcast,
2958 TosaTensorGen.tgBroadcastFuzz,
2959 TosaTensorValuesGen.tvgDefault,
2960 None,
2961 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002962 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002963 "error_if_validators": (
2964 TosaErrorValidator.evRankMismatch,
2965 TosaErrorValidator.evWrongInputType,
2966 TosaErrorValidator.evWrongOutputType,
2967 TosaErrorValidator.evWrongInputList,
2968 TosaErrorValidator.evWrongOutputList,
2969 TosaErrorValidator.evDimensionMismatch,
2970 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002971 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002972 "logical_left_shift": {
2973 "op": Op.LOGICAL_LEFT_SHIFT,
2974 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002975 "build_fcn": (
2976 build_binary_broadcast,
2977 TosaTensorGen.tgBroadcastFuzz,
2978 TosaTensorValuesGen.tvgLogicalShift,
2979 None,
2980 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002981 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002982 "error_if_validators": (
2983 TosaErrorValidator.evRankMismatch,
2984 TosaErrorValidator.evWrongInputType,
2985 TosaErrorValidator.evWrongOutputType,
2986 TosaErrorValidator.evWrongInputList,
2987 TosaErrorValidator.evWrongOutputList,
2988 TosaErrorValidator.evDimensionMismatch,
2989 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002990 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002991 "logical_right_shift": {
2992 "op": Op.LOGICAL_RIGHT_SHIFT,
2993 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002994 "build_fcn": (
2995 build_binary_broadcast,
2996 TosaTensorGen.tgBroadcastFuzz,
2997 TosaTensorValuesGen.tvgLogicalShift,
2998 None,
2999 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003000 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003001 "error_if_validators": (
3002 TosaErrorValidator.evRankMismatch,
3003 TosaErrorValidator.evWrongInputType,
3004 TosaErrorValidator.evWrongOutputType,
3005 TosaErrorValidator.evWrongInputList,
3006 TosaErrorValidator.evWrongOutputList,
3007 TosaErrorValidator.evDimensionMismatch,
3008 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003009 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003010 "logical_or": {
3011 "op": Op.LOGICAL_OR,
3012 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003013 "build_fcn": (
3014 build_binary_broadcast,
3015 TosaTensorGen.tgBroadcastFuzz,
3016 TosaTensorValuesGen.tvgDefault,
3017 None,
3018 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003019 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003020 "error_if_validators": (
3021 TosaErrorValidator.evRankMismatch,
3022 TosaErrorValidator.evWrongInputType,
3023 TosaErrorValidator.evWrongOutputType,
3024 TosaErrorValidator.evWrongInputList,
3025 TosaErrorValidator.evWrongOutputList,
3026 TosaErrorValidator.evDimensionMismatch,
3027 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003028 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003029 "logical_xor": {
3030 "op": Op.LOGICAL_XOR,
3031 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003032 "build_fcn": (
3033 build_binary_broadcast,
3034 TosaTensorGen.tgBroadcastFuzz,
3035 TosaTensorValuesGen.tvgDefault,
3036 None,
3037 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003038 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003039 "error_if_validators": (
3040 TosaErrorValidator.evRankMismatch,
3041 TosaErrorValidator.evWrongInputType,
3042 TosaErrorValidator.evWrongOutputType,
3043 TosaErrorValidator.evWrongInputList,
3044 TosaErrorValidator.evWrongOutputList,
3045 TosaErrorValidator.evDimensionMismatch,
3046 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003047 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003048 "maximum": {
3049 "op": Op.MAXIMUM,
3050 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003051 "build_fcn": (
3052 build_binary_broadcast,
3053 TosaTensorGen.tgBroadcastFuzz,
3054 TosaTensorValuesGen.tvgDefault,
3055 None,
3056 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003057 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003058 "error_if_validators": (
3059 TosaErrorValidator.evRankMismatch,
3060 TosaErrorValidator.evWrongInputType,
3061 TosaErrorValidator.evWrongOutputType,
3062 TosaErrorValidator.evWrongInputList,
3063 TosaErrorValidator.evWrongOutputList,
3064 TosaErrorValidator.evDimensionMismatch,
3065 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003066 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003067 "minimum": {
3068 "op": Op.MINIMUM,
3069 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003070 "build_fcn": (
3071 build_binary_broadcast,
3072 TosaTensorGen.tgBroadcastFuzz,
3073 TosaTensorValuesGen.tvgDefault,
3074 None,
3075 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003076 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003077 "error_if_validators": (
3078 TosaErrorValidator.evRankMismatch,
3079 TosaErrorValidator.evWrongInputType,
3080 TosaErrorValidator.evWrongOutputType,
3081 TosaErrorValidator.evWrongInputList,
3082 TosaErrorValidator.evWrongOutputList,
3083 TosaErrorValidator.evDimensionMismatch,
3084 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003085 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003086 "mul": {
3087 "op": Op.MUL,
3088 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003089 "build_fcn": (
3090 build_mul,
3091 TosaTensorGen.tgBroadcastFuzz,
3092 TosaTensorValuesGen.tvgMul,
3093 TosaArgGen.agMul,
3094 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003095 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003096 "error_if_validators": (
3097 TosaErrorValidator.evWrongInputType,
3098 TosaErrorValidator.evWrongOutputType,
3099 TosaErrorValidator.evWrongInputList,
3100 TosaErrorValidator.evWrongOutputList,
3101 TosaErrorValidator.evRankMismatch,
3102 TosaErrorValidator.evDimensionMismatch,
3103 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003104 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003105 "pow": {
3106 "op": Op.POW,
3107 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003108 "build_fcn": (
3109 build_binary_broadcast,
3110 TosaTensorGen.tgBroadcastFuzz,
3111 TosaTensorValuesGen.tvgDefault,
3112 None,
3113 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003114 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003115 "error_if_validators": (
3116 TosaErrorValidator.evRankMismatch,
3117 TosaErrorValidator.evWrongInputType,
3118 TosaErrorValidator.evWrongOutputType,
3119 TosaErrorValidator.evWrongInputList,
3120 TosaErrorValidator.evWrongOutputList,
3121 TosaErrorValidator.evDimensionMismatch,
3122 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003123 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003124 "sub": {
3125 "op": Op.SUB,
3126 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003127 "build_fcn": (
3128 build_binary_broadcast,
3129 TosaTensorGen.tgBroadcastFuzz,
3130 TosaTensorValuesGen.tvgAddSub,
3131 None,
3132 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003133 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003134 "error_if_validators": (
3135 TosaErrorValidator.evRankMismatch,
3136 TosaErrorValidator.evWrongInputType,
3137 TosaErrorValidator.evWrongOutputType,
3138 TosaErrorValidator.evWrongInputList,
3139 TosaErrorValidator.evWrongOutputList,
3140 TosaErrorValidator.evDimensionMismatch,
3141 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003142 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003143 "table": {
3144 "op": Op.TABLE,
3145 # Use the automatic generation functions to create the input array
3146 # but create the table tensor in the build function, as it may be
3147 # a different type from the input
3148 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003149 "build_fcn": (
3150 build_table,
3151 TosaTensorGen.tgBasic,
3152 TosaTensorValuesGen.tvgDefault,
3153 TosaArgGen.agTable,
3154 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003155 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003156 "error_if_validators": (
3157 TosaErrorValidator.evWrongInputType,
3158 TosaErrorValidator.evWrongOutputType,
3159 TosaErrorValidator.evWrongInputList,
3160 TosaErrorValidator.evWrongOutputList,
3161 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003162 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003163 # Elementwise Unary operators
3164 "abs": {
3165 "op": Op.ABS,
3166 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003167 "build_fcn": (
3168 build_unary,
3169 TosaTensorGen.tgBasic,
3170 TosaTensorValuesGen.tvgDefault,
3171 None,
3172 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003173 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003174 "error_if_validators": (
3175 TosaErrorValidator.evWrongInputType,
3176 TosaErrorValidator.evWrongOutputType,
3177 TosaErrorValidator.evWrongInputList,
3178 TosaErrorValidator.evWrongOutputList,
3179 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003180 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003181 "bitwise_not": {
3182 "op": Op.BITWISE_NOT,
3183 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003184 "build_fcn": (
3185 build_unary,
3186 TosaTensorGen.tgBasic,
3187 TosaTensorValuesGen.tvgDefault,
3188 None,
3189 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003190 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003191 "error_if_validators": (
3192 TosaErrorValidator.evWrongInputType,
3193 TosaErrorValidator.evWrongOutputType,
3194 TosaErrorValidator.evWrongInputList,
3195 TosaErrorValidator.evWrongOutputList,
3196 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003197 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003198 "ceil": {
3199 "op": Op.CEIL,
3200 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003201 "build_fcn": (
3202 build_unary,
3203 TosaTensorGen.tgBasic,
3204 TosaTensorValuesGen.tvgDefault,
3205 None,
3206 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003207 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003208 "error_if_validators": (
3209 TosaErrorValidator.evWrongInputType,
3210 TosaErrorValidator.evWrongOutputType,
3211 TosaErrorValidator.evWrongInputList,
3212 TosaErrorValidator.evWrongOutputList,
3213 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003214 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003215 "clz": {
3216 "op": Op.CLZ,
3217 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003218 "build_fcn": (
3219 build_unary,
3220 TosaTensorGen.tgBasic,
3221 TosaTensorValuesGen.tvgDefault,
3222 None,
3223 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003224 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003225 "error_if_validators": (
3226 TosaErrorValidator.evWrongInputType,
3227 TosaErrorValidator.evWrongOutputType,
3228 TosaErrorValidator.evWrongInputList,
3229 TosaErrorValidator.evWrongOutputList,
3230 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003231 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003232 "exp": {
3233 "op": Op.EXP,
3234 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003235 "build_fcn": (
3236 build_unary,
3237 TosaTensorGen.tgBasic,
3238 TosaTensorValuesGen.tvgDefault,
3239 None,
3240 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003241 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003242 "error_if_validators": (
3243 TosaErrorValidator.evWrongInputType,
3244 TosaErrorValidator.evWrongOutputType,
3245 TosaErrorValidator.evWrongInputList,
3246 TosaErrorValidator.evWrongOutputList,
3247 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003248 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003249 "floor": {
3250 "op": Op.FLOOR,
3251 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003252 "build_fcn": (
3253 build_unary,
3254 TosaTensorGen.tgBasic,
3255 TosaTensorValuesGen.tvgDefault,
3256 None,
3257 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003258 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003259 "error_if_validators": (
3260 TosaErrorValidator.evWrongInputType,
3261 TosaErrorValidator.evWrongOutputType,
3262 TosaErrorValidator.evWrongInputList,
3263 TosaErrorValidator.evWrongOutputList,
3264 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003265 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003266 "log": {
3267 "op": Op.LOG,
3268 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003269 "build_fcn": (
3270 build_unary,
3271 TosaTensorGen.tgBasic,
3272 TosaTensorValuesGen.tvgDefault,
3273 None,
3274 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003275 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003276 "error_if_validators": (
3277 TosaErrorValidator.evWrongInputType,
3278 TosaErrorValidator.evWrongOutputType,
3279 TosaErrorValidator.evWrongInputList,
3280 TosaErrorValidator.evWrongOutputList,
3281 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003282 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003283 "logical_not": {
3284 "op": Op.LOGICAL_NOT,
3285 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003286 "build_fcn": (
3287 build_unary,
3288 TosaTensorGen.tgBasic,
3289 TosaTensorValuesGen.tvgDefault,
3290 None,
3291 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003292 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003293 "error_if_validators": (
3294 TosaErrorValidator.evWrongInputType,
3295 TosaErrorValidator.evWrongOutputType,
3296 TosaErrorValidator.evWrongInputList,
3297 TosaErrorValidator.evWrongOutputList,
3298 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003299 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003300 "negate": {
3301 "op": Op.NEGATE,
3302 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003303 "build_fcn": (
3304 build_unary,
3305 TosaTensorGen.tgBasic,
3306 TosaTensorValuesGen.tvgNegate,
3307 None,
3308 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003309 "qgen": TosaQuantGen.qgUnary,
3310 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003311 "error_if_validators": (
3312 TosaErrorValidator.evInputZeroPointNotZero,
3313 TosaErrorValidator.evOutputZeroPointNotZero,
3314 TosaErrorValidator.evWrongInputType,
3315 TosaErrorValidator.evWrongOutputType,
3316 TosaErrorValidator.evWrongInputList,
3317 TosaErrorValidator.evWrongOutputList,
3318 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003319 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003320 "reciprocal": {
3321 "op": Op.RECIPROCAL,
3322 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003323 "build_fcn": (
3324 build_unary,
3325 TosaTensorGen.tgBasic,
3326 TosaTensorValuesGen.tvgDefault,
3327 None,
3328 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003329 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003330 "error_if_validators": (
3331 TosaErrorValidator.evWrongInputType,
3332 TosaErrorValidator.evWrongOutputType,
3333 TosaErrorValidator.evWrongInputList,
3334 TosaErrorValidator.evWrongOutputList,
3335 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003336 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003337 "rsqrt": {
3338 "op": Op.RSQRT,
3339 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003340 "build_fcn": (
3341 build_unary,
3342 TosaTensorGen.tgBasic,
3343 TosaTensorValuesGen.tvgDefault,
3344 None,
3345 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003346 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003347 "error_if_validators": (
3348 TosaErrorValidator.evWrongInputType,
3349 TosaErrorValidator.evWrongOutputType,
3350 TosaErrorValidator.evWrongInputList,
3351 TosaErrorValidator.evWrongOutputList,
3352 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003353 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003354 # Elementwise Ternary operators
3355 "select": {
3356 "op": Op.SELECT,
3357 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003358 "build_fcn": (
3359 build_select,
3360 TosaTensorGen.tgBroadcastFuzz,
3361 TosaTensorValuesGen.tvgSelect,
3362 None,
3363 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003364 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003365 "error_if_validators": (
3366 TosaErrorValidator.evRankMismatch,
3367 TosaErrorValidator.evWrongInputType,
3368 TosaErrorValidator.evWrongOutputType,
3369 TosaErrorValidator.evWrongInputList,
3370 TosaErrorValidator.evWrongOutputList,
3371 TosaErrorValidator.evDimensionMismatch,
3372 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003373 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003374 # Comparison operators
3375 "equal": {
3376 "op": Op.EQUAL,
3377 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003378 "build_fcn": (
3379 build_comparison,
3380 TosaTensorGen.tgBroadcastFuzz,
3381 TosaTensorValuesGen.tvgEqual,
3382 None,
3383 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003384 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003385 "error_if_validators": (
3386 TosaErrorValidator.evRankMismatch,
3387 TosaErrorValidator.evWrongInputType,
3388 TosaErrorValidator.evWrongOutputType,
3389 TosaErrorValidator.evWrongInputList,
3390 TosaErrorValidator.evWrongOutputList,
3391 TosaErrorValidator.evDimensionMismatch,
3392 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003393 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003394 "greater_equal": {
3395 "op": Op.GREATER_EQUAL,
3396 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003397 "build_fcn": (
3398 build_comparison,
3399 TosaTensorGen.tgBroadcastFuzz,
3400 TosaTensorValuesGen.tvgDefault,
3401 None,
3402 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003403 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003404 "error_if_validators": (
3405 TosaErrorValidator.evRankMismatch,
3406 TosaErrorValidator.evWrongInputType,
3407 TosaErrorValidator.evWrongOutputType,
3408 TosaErrorValidator.evWrongInputList,
3409 TosaErrorValidator.evWrongOutputList,
3410 TosaErrorValidator.evDimensionMismatch,
3411 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003412 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003413 "greater": {
3414 "op": Op.GREATER,
3415 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003416 "build_fcn": (
3417 build_comparison,
3418 TosaTensorGen.tgBroadcastFuzz,
3419 TosaTensorValuesGen.tvgDefault,
3420 None,
3421 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003422 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003423 "error_if_validators": (
3424 TosaErrorValidator.evRankMismatch,
3425 TosaErrorValidator.evWrongInputType,
3426 TosaErrorValidator.evWrongOutputType,
3427 TosaErrorValidator.evWrongInputList,
3428 TosaErrorValidator.evWrongOutputList,
3429 TosaErrorValidator.evDimensionMismatch,
3430 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003431 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003432 # Reduction operators
3433 "reduce_all": {
3434 "op": Op.REDUCE_ALL,
3435 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003436 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003437 "build_fcn": (
3438 build_reduce,
3439 TosaTensorGen.tgBasic,
3440 TosaTensorValuesGen.tvgDefault,
3441 TosaArgGen.agAxis,
3442 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003443 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003444 "error_if_validators": (
3445 TosaErrorValidator.evAxisLargerRank,
3446 TosaErrorValidator.evAxisSmallerZero,
3447 TosaErrorValidator.evShapeOfAxisNotOne,
3448 TosaErrorValidator.evWrongInputType,
3449 TosaErrorValidator.evWrongOutputType,
3450 TosaErrorValidator.evWrongRank,
3451 TosaErrorValidator.evWrongInputList,
3452 TosaErrorValidator.evWrongOutputList,
3453 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003454 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003455 "reduce_any": {
3456 "op": Op.REDUCE_ANY,
3457 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003458 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003459 "build_fcn": (
3460 build_reduce,
3461 TosaTensorGen.tgBasic,
3462 TosaTensorValuesGen.tvgDefault,
3463 TosaArgGen.agAxis,
3464 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003465 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003466 "error_if_validators": (
3467 TosaErrorValidator.evAxisLargerRank,
3468 TosaErrorValidator.evAxisSmallerZero,
3469 TosaErrorValidator.evShapeOfAxisNotOne,
3470 TosaErrorValidator.evWrongInputType,
3471 TosaErrorValidator.evWrongOutputType,
3472 TosaErrorValidator.evWrongRank,
3473 TosaErrorValidator.evWrongInputList,
3474 TosaErrorValidator.evWrongOutputList,
3475 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003476 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003477 "reduce_max": {
3478 "op": Op.REDUCE_MAX,
3479 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003480 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003481 "build_fcn": (
3482 build_reduce,
3483 TosaTensorGen.tgBasic,
3484 TosaTensorValuesGen.tvgDefault,
3485 TosaArgGen.agAxis,
3486 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003487 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003488 "error_if_validators": (
3489 TosaErrorValidator.evAxisLargerRank,
3490 TosaErrorValidator.evAxisSmallerZero,
3491 TosaErrorValidator.evShapeOfAxisNotOne,
3492 TosaErrorValidator.evWrongInputType,
3493 TosaErrorValidator.evWrongOutputType,
3494 TosaErrorValidator.evWrongRank,
3495 TosaErrorValidator.evWrongInputList,
3496 TosaErrorValidator.evWrongOutputList,
3497 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003498 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003499 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003500 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003501 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003502 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003503 "build_fcn": (
3504 build_reduce,
3505 TosaTensorGen.tgBasic,
3506 TosaTensorValuesGen.tvgDefault,
3507 TosaArgGen.agAxis,
3508 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003509 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003510 "error_if_validators": (
3511 TosaErrorValidator.evAxisLargerRank,
3512 TosaErrorValidator.evAxisSmallerZero,
3513 TosaErrorValidator.evShapeOfAxisNotOne,
3514 TosaErrorValidator.evWrongInputType,
3515 TosaErrorValidator.evWrongOutputType,
3516 TosaErrorValidator.evWrongRank,
3517 TosaErrorValidator.evWrongInputList,
3518 TosaErrorValidator.evWrongOutputList,
3519 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003520 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003521 "reduce_product": {
3522 "op": Op.REDUCE_PRODUCT,
3523 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003524 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003525 "build_fcn": (
3526 build_reduce,
3527 TosaTensorGen.tgBasic,
3528 TosaTensorValuesGen.tvgDefault,
3529 TosaArgGen.agAxis,
3530 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003531 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003532 "error_if_validators": (
3533 TosaErrorValidator.evAxisLargerRank,
3534 TosaErrorValidator.evAxisSmallerZero,
3535 TosaErrorValidator.evShapeOfAxisNotOne,
3536 TosaErrorValidator.evWrongInputType,
3537 TosaErrorValidator.evWrongOutputType,
3538 TosaErrorValidator.evWrongRank,
3539 TosaErrorValidator.evWrongInputList,
3540 TosaErrorValidator.evWrongOutputList,
3541 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003542 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003543 "reduce_sum": {
3544 "op": Op.REDUCE_SUM,
3545 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003546 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003547 "build_fcn": (
3548 build_reduce,
3549 TosaTensorGen.tgBasic,
3550 TosaTensorValuesGen.tvgReduceSum,
3551 TosaArgGen.agAxis,
3552 ),
James Ward24dbc422022-10-19 12:20:31 +01003553 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003554 "error_if_validators": (
3555 TosaErrorValidator.evAxisLargerRank,
3556 TosaErrorValidator.evAxisSmallerZero,
3557 TosaErrorValidator.evShapeOfAxisNotOne,
3558 TosaErrorValidator.evWrongInputType,
3559 TosaErrorValidator.evWrongOutputType,
3560 TosaErrorValidator.evWrongRank,
3561 TosaErrorValidator.evWrongInputList,
3562 TosaErrorValidator.evWrongOutputList,
3563 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003564 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003565 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003566 "concat": {
3567 "op": Op.CONCAT,
3568 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003569 "build_fcn": (
3570 build_concat,
3571 TosaTensorGen.tgConcat,
3572 TosaTensorValuesGen.tvgConcat,
3573 TosaArgGen.agAxis,
3574 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003575 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003576 "error_if_validators": (
3577 TosaErrorValidator.evAxisLargerRank,
3578 TosaErrorValidator.evAxisSmallerZero,
3579 TosaErrorValidator.evConcatInputRankMismatch,
3580 TosaErrorValidator.evConcatShapeSumMismatch,
3581 TosaErrorValidator.evConcatInputDimMismatch,
3582 TosaErrorValidator.evWrongInputType,
3583 TosaErrorValidator.evWrongOutputType,
3584 TosaErrorValidator.evWrongOutputList,
3585 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003586 },
3587 "pad": {
3588 "op": Op.PAD,
3589 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003590 "rank": (1, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003591 "build_fcn": (
3592 build_pad,
3593 TosaTensorGen.tgBasic,
3594 TosaTensorValuesGen.tvgDefault,
3595 TosaArgGen.agPad,
3596 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003597 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003598 "error_if_validators": (
3599 TosaErrorValidator.evWrongInputType,
3600 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003601 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003602 TosaErrorValidator.evWrongOutputType,
3603 TosaErrorValidator.evWrongInputList,
3604 TosaErrorValidator.evWrongOutputList,
3605 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003606 },
3607 "reshape": {
3608 "op": Op.RESHAPE,
3609 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003610 "build_fcn": (
3611 build_reshape,
3612 TosaTensorGen.tgBasic,
3613 TosaTensorValuesGen.tvgDefault,
3614 TosaArgGen.agReshape,
3615 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003616 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003617 "error_if_validators": (
3618 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3619 TosaErrorValidator.evWrongInputType,
3620 TosaErrorValidator.evWrongOutputType,
3621 TosaErrorValidator.evWrongInputList,
3622 TosaErrorValidator.evWrongOutputList,
3623 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003624 },
3625 "reverse": {
3626 "op": Op.REVERSE,
3627 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003628 "build_fcn": (
3629 build_reverse,
3630 TosaTensorGen.tgBasic,
3631 TosaTensorValuesGen.tvgDefault,
3632 TosaArgGen.agAxis,
3633 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003634 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003635 "error_if_validators": (
3636 TosaErrorValidator.evAxisSmallerZero,
3637 TosaErrorValidator.evAxisLargerRank,
3638 TosaErrorValidator.evWrongInputType,
3639 TosaErrorValidator.evWrongOutputType,
3640 TosaErrorValidator.evWrongInputList,
3641 TosaErrorValidator.evWrongOutputList,
3642 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003643 },
3644 "slice": {
3645 "op": Op.SLICE,
3646 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003647 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003648 "build_fcn": (
3649 build_slice,
3650 TosaTensorGen.tgBasic,
3651 TosaTensorValuesGen.tvgDefault,
3652 TosaArgGen.agSlice,
3653 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003654 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003655 "error_if_validators": (
3656 TosaErrorValidator.evStartSmallerZero,
3657 TosaErrorValidator.evSizeSmallerEqualZero,
3658 TosaErrorValidator.evStartSizeOutsideBounds,
3659 TosaErrorValidator.evSizeOutputShapeMismatch,
3660 TosaErrorValidator.evInputSizeStartLengthMismatch,
3661 TosaErrorValidator.evWrongRank,
3662 TosaErrorValidator.evWrongInputType,
3663 TosaErrorValidator.evWrongOutputType,
3664 TosaErrorValidator.evWrongInputList,
3665 TosaErrorValidator.evWrongOutputList,
3666 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003667 },
3668 "tile": {
3669 "op": Op.TILE,
3670 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003671 "build_fcn": (
3672 build_tile,
3673 TosaTensorGen.tgBasic,
3674 TosaTensorValuesGen.tvgDefault,
3675 TosaArgGen.agTile,
3676 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003677 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003678 "error_if_validators": (
3679 TosaErrorValidator.evWrongInputType,
3680 TosaErrorValidator.evWrongOutputType,
3681 TosaErrorValidator.evWrongInputList,
3682 TosaErrorValidator.evWrongOutputList,
3683 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003684 },
3685 "transpose": {
3686 "op": Op.TRANSPOSE,
3687 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003688 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003689 "build_fcn": (
3690 build_transpose,
3691 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003692 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003693 TosaArgGen.agTranspose,
3694 ),
3695 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003696 "error_if_validators": (
3697 TosaErrorValidator.evIndexOutsideBounds,
3698 TosaErrorValidator.evIndexUsedTwice,
3699 TosaErrorValidator.evWrongInputType,
3700 TosaErrorValidator.evWrongOutputType,
3701 TosaErrorValidator.evWrongInputList,
3702 TosaErrorValidator.evWrongOutputList,
3703 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003704 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003705 # Data nodes
3706 "const": {
3707 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003708 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003709 "build_fcn": (
3710 build_const,
3711 TosaTensorGen.tgBasic,
3712 TosaTensorValuesGen.tvgDefault,
3713 None,
3714 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003715 "types": TYPE_FIB,
3716 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003717 "identity": {
3718 "op": Op.IDENTITY,
3719 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003720 "build_fcn": (
3721 build_unary,
3722 TosaTensorGen.tgBasic,
3723 TosaTensorValuesGen.tvgDefault,
3724 None,
3725 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003726 "types": TYPE_FIB,
3727 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003728 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003729 "gather": {
3730 "op": Op.GATHER,
3731 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3732 "operands": (1, 0),
3733 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003734 "build_fcn": (
3735 build_gather,
3736 TosaTensorGen.tgBasic,
3737 TosaTensorValuesGen.tvgDefault,
3738 None,
3739 ),
James Ward24dbc422022-10-19 12:20:31 +01003740 "types": (
3741 DType.INT8,
3742 DType.INT16,
3743 DType.INT32,
3744 DType.FP16,
3745 DType.BF16,
3746 DType.FP32,
3747 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003748 "error_if_validators": (
3749 TosaErrorValidator.evWrongInputType,
3750 TosaErrorValidator.evWrongOutputType,
3751 TosaErrorValidator.evWrongInputList,
3752 TosaErrorValidator.evWrongOutputList,
3753 TosaErrorValidator.evWrongRank,
3754 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003755 },
3756 "scatter": {
3757 "op": Op.SCATTER,
3758 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003759 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003760 "operands": (2, 0),
3761 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003762 "build_fcn": (
3763 build_scatter,
3764 TosaTensorGen.tgScatter,
3765 TosaTensorValuesGen.tvgDefault,
3766 None,
3767 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003768 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003769 "error_if_validators": (
3770 TosaErrorValidator.evWrongInputType,
3771 TosaErrorValidator.evWrongOutputType,
3772 TosaErrorValidator.evWrongInputList,
3773 TosaErrorValidator.evWrongOutputList,
3774 TosaErrorValidator.evWrongRank,
3775 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003776 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003777 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003778 "resize": {
3779 "op": Op.RESIZE,
3780 "operands": (1, 0),
3781 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003782 "build_fcn": (
3783 build_resize,
3784 TosaTensorGen.tgNHWC,
3785 TosaTensorValuesGen.tvgDefault,
3786 TosaArgGen.agResize,
3787 ),
James Ward24dbc422022-10-19 12:20:31 +01003788 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003789 "invalid_test_validators": (
3790 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003791 ),
3792 "error_if_validators": (
3793 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003794 TosaErrorValidator.evScaleSmallerEqualZero,
3795 TosaErrorValidator.evScaleNLargerMax,
3796 TosaErrorValidator.evScaleDLargerMax,
3797 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003798 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003799 TosaErrorValidator.evBorderSmallerMin,
3800 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003801 TosaErrorValidator.evWrongInputType,
3802 TosaErrorValidator.evWrongOutputType,
3803 TosaErrorValidator.evWrongRank,
3804 TosaErrorValidator.evWrongInputList,
3805 TosaErrorValidator.evWrongOutputList,
3806 TosaErrorValidator.evBatchMismatch,
3807 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003808 TosaErrorValidator.evResizeOutputShapeMismatch,
3809 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003810 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003811 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003812 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003813 "cast": {
3814 "op": Op.CAST,
3815 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003816 "build_fcn": (
3817 build_cast,
3818 TosaTensorGen.tgBasic,
3819 TosaTensorValuesGen.tvgDefault,
3820 TosaArgGen.agCast,
3821 ),
James Ward8b390432022-08-12 20:48:56 +01003822 "types": (
3823 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01003824 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003825 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01003826 DType.INT8,
3827 DType.INT16,
3828 DType.INT32,
3829 DType.BOOL,
3830 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003831 "error_if_validators": (
3832 TosaErrorValidator.evWrongInputType,
3833 TosaErrorValidator.evWrongOutputType,
3834 TosaErrorValidator.evWrongInputList,
3835 TosaErrorValidator.evWrongOutputList,
3836 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003837 },
3838 "rescale": {
3839 "op": Op.RESCALE,
3840 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003841 "build_fcn": (
3842 build_rescale,
3843 TosaTensorGen.tgBasic,
3844 TosaTensorValuesGen.tvgDefault,
3845 TosaArgGen.agRescale,
3846 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003847 "types": [
3848 DType.UINT8,
3849 DType.INT8,
3850 DType.INT16,
3851 DType.INT32,
3852 DType.INT48,
3853 DType.UINT16,
3854 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003855 "error_if_validators": (
3856 TosaErrorValidator.evInputZeroPointNotZero,
3857 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003858 TosaErrorValidator.evU16InputZeroPointNotValid,
3859 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003860 TosaErrorValidator.evScaleTrue,
3861 TosaErrorValidator.evScaleNotTrue,
3862 TosaErrorValidator.evWrongInputType,
3863 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003864 TosaErrorValidator.evWrongInputList,
3865 TosaErrorValidator.evWrongOutputList,
3866 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003867 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003868 # Custom
3869 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003870 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07003871 # Two variants of cond_if, one that generates one of two constant tensors (no
3872 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
3873 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003874 "cond_if_const": {
3875 "op": Op.COND_IF,
3876 "operands": (0, 2),
3877 "build_fcn": (
3878 build_cond_if_const,
3879 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003880 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003881 TosaArgGen.agCondIf,
3882 ),
3883 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003884 "error_if_validators": (
3885 TosaErrorValidator.evOutputListThenGraphMismatch,
3886 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003887 TosaErrorValidator.evCondIfCondNotMatchingBool,
3888 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003889 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003890 },
3891 "cond_if_binary": {
3892 "op": Op.COND_IF,
3893 "operands": (2, 0),
3894 "build_fcn": (
3895 build_cond_if_binary,
3896 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003897 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003898 TosaArgGen.agCondIf,
3899 ),
Les Bell6040b4d2021-10-11 12:50:31 +01003900 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003901 "error_if_validators": (
3902 TosaErrorValidator.evInputListThenGraphMismatch,
3903 TosaErrorValidator.evInputListElseGraphMismatch,
3904 TosaErrorValidator.evOutputListThenGraphMismatch,
3905 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003906 TosaErrorValidator.evCondIfCondNotMatchingBool,
3907 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003908 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003909 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003910 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003911 "while_loop": {
3912 "op": Op.WHILE_LOOP,
3913 "operands": (0, 1),
3914 "build_fcn": (
3915 build_while_loop,
3916 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003917 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003918 TosaArgGen.agWhileLoop,
3919 ),
3920 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003921 "error_if_validators": (
3922 TosaErrorValidator.evInputListOutputListMismatch,
3923 TosaErrorValidator.evInputListCondGraphMismatch,
3924 TosaErrorValidator.evInputListBodyGraphInputMismatch,
3925 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
3926 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003927 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003928 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003929 },
Luke Hutton261b7b62023-01-10 14:50:31 +00003930 "rfft2d": {
3931 "op": Op.RFFT2D,
3932 "operands": (1, 0),
3933 "rank": (3, 3),
3934 "build_fcn": (
3935 build_rfft2d,
3936 TosaTensorGen.tgRFFT2d,
3937 TosaTensorValuesGen.tvgDefault,
3938 TosaArgGen.agNone,
3939 ),
3940 "types": [DType.FP32],
3941 "error_if_validators": (
3942 TosaErrorValidator.evWrongInputType,
3943 TosaErrorValidator.evWrongOutputType,
3944 TosaErrorValidator.evWrongInputList,
3945 TosaErrorValidator.evWrongOutputList,
3946 TosaErrorValidator.evWrongRank,
3947 TosaErrorValidator.evBatchMismatch,
3948 TosaErrorValidator.evKernelNotPowerOfTwo,
3949 ),
3950 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003951 }
3952
Kevin Cheng550ccc52021-03-03 11:21:43 -08003953
Eric Kunzee5e26762020-10-13 16:11:07 -07003954class OutputShaper:
3955 # Methods in this class compute the expected output shape and datatype
3956 # for common classes of operations
def __init__(self):
    """Construct an OutputShaper; no per-instance state is set up here."""
3959
3960 # These methods create the expected output tensor by calling
3961 # ser.addOutput(shape, dtype) and return its result
3962 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a broadcasting elementwise binary op.

    When no ERROR_IF case is requested, every dimension of ``a`` equal to 1
    takes its size from the matching dimension of ``b``; otherwise ``a``'s
    shape is used unchanged. The output dtype is ``a.dtype`` unless
    ``error_name`` is ``ErrorIf.WrongOutputType``, in which case a dtype
    different from ``a.dtype`` is drawn with ``rng.choice``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator providing ``choice`` (presumably a numpy
            Generator - TODO confirm).
        a: first input tensor, exposing ``shape`` and ``dtype``.
        b: second input tensor, exposing ``shape`` and ``dtype``.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    # Ranks are only allowed to differ when generating the RankMismatch case.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    broadcasting = error_name is None
    out_shape = [
        b.shape[idx] if (dim == 1 and broadcasting) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]
    # Keep the set-based construction: its iteration order is what
    # rng.choice samples from, so it decides which dtype a seed selects.
    wrong = list(set(candidates) - set([a.dtype]))
    return ser.addOutput(out_shape, rng.choice(wrong))
Eric Kunzee5e26762020-10-13 16:11:07 -07003991
3992 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for a binary op whose inputs must match exactly.

    Both inputs must share rank, dtype and every dimension; this is checked
    with asserts. The output copies ``a``'s shape (as a new list) and dtype.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        a: first input tensor, exposing ``shape`` and ``dtype``.
        b: second input tensor, exposing ``shape`` and ``dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    for a_dim, b_dim in zip(a.shape, b.shape):
        assert a_dim == b_dim

    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004003
4004 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for an elementwise unary op.

    The output always has ``a``'s shape. Its dtype is ``a.dtype`` except
    for ``ErrorIf.WrongOutputType``, where ``rng.choice`` picks a dtype
    different from ``a.dtype``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator providing ``choice``.
        a: input tensor, exposing ``shape`` and ``dtype``.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # The set difference fixes the order that rng.choice samples from.
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))
    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004022
4023 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for SELECT.

    The shape follows ``cond``. When no ERROR_IF case is requested, a
    dimension of ``cond`` equal to 1 is widened to the largest of the
    corresponding ``cond``, ``a`` and ``b`` dimensions. The dtype is
    ``a.dtype`` unless ``error_name`` is ``ErrorIf.WrongOutputType``, where
    ``rng.choice`` picks a dtype different from ``a.dtype``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator providing ``choice``.
        cond: condition tensor, exposing ``shape``.
        a: first value tensor, exposing ``shape`` and ``dtype``.
        b: second value tensor, exposing ``shape`` and ``dtype``.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    # All three ranks must agree except when generating RankMismatch.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for idx, cond_dim in enumerate(cond.shape):
        if error_name is None and cond_dim == 1:
            out_shape.append(max(cond_dim, a.shape[idx], b.shape[idx]))
        else:
            out_shape.append(cond_dim)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # The set difference fixes the order that rng.choice samples from.
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004052
4053 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a comparison op (normally BOOL).

    Each dimension of ``a`` equal to 1 takes ``b``'s size for that
    dimension, provided ``b`` has that many dimensions. The dtype is
    ``DType.BOOL``, or for ``ErrorIf.WrongOutputType`` one of a fixed list
    of non-BOOL dtypes chosen with ``rng.choice``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator providing ``choice``.
        a: first input tensor, exposing ``shape`` and ``dtype``.
        b: second input tensor, exposing ``shape`` and ``dtype``.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast; the rank check stops indexing past the end of b.shape.
    b_rank = len(b.shape)
    out_shape = [
        b.shape[idx] if (dim == 1 and b_rank > idx) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.BOOL)

    non_bool = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput(out_shape, rng.choice(non_bool))
Eric Kunzee5e26762020-10-13 16:11:07 -07004082
4083 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a REDUCE_* op.

    Normally dimension ``axis`` of ``a``'s shape becomes 1. For the
    AxisSmallerZero, AxisLargerRank and ShapeOfAxisNotOne ERROR_IF cases
    the shape is kept. For ShapeOfAxisNotOne, an ``axis`` dimension that is
    already 1 is replaced with ``rng.integers(2, 10)``. The dtype is
    ``a.dtype`` unless ``error_name`` is ``ErrorIf.WrongOutputType``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator providing ``choice`` and ``integers``.
        a: input tensor, exposing ``shape`` (a list) and ``dtype``.
        axis: index of the reduced dimension.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_shape = a.shape.copy()
    if error_name not in (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    ):
        out_shape[axis] = 1
    elif error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # The set difference fixes the order that rng.choice samples from.
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004111
4112 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for ARGMAX (normally INT32).

    Normally dimension ``axis`` is removed from ``a``'s shape. The axis
    ERROR_IF cases (AxisSmallerZero, AxisLargerRank) keep the full shape.
    ArgmaxOutputRankMismatch randomly drops the leading dimension (only if
    more than one remains) or appends a 1. ArgmaxOutputShapeMismatch grows
    every dimension by ``rng.integers(1, 10)``. WrongOutputType picks a
    non-INT32 dtype with ``rng.choice``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator providing ``choice`` and ``integers``.
        a: input tensor, exposing ``shape`` (a list).
        axis: index of the dimension ARGMAX reduces over.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        if rng.choice([True, False]) and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # The set difference fixes the order that rng.choice samples from.
        out_dtype = rng.choice(list(set(candidates) - set([DType.INT32])))
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004145
4146 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a 2D convolution.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC. Each spatial output
    size is
    ``(in - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator providing ``choice``.
        ifm: input tensor, exposing ``shape`` and ``dtype``.
        filter: weight tensor, exposing ``shape``.
        accum_dtype: output dtype in the non-error case.
        strides: ``strides[0]`` applies to height, ``strides[1]`` to width.
        padding: ``padding[0:2]`` pad the height, ``padding[2:4]`` the width.
        dilations: ``dilations[0]`` for height, ``dilations[1]`` for width.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """

    def out_size(in_size, pad_a, pad_b, kernel, dilation, stride):
        # Output extent of one spatial dimension.
        return (in_size - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1

    h = out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    w = out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow by whole strides so the non-integer-output error is not hit.
        if change in (1, 3):
            h = h + rng.choice(choices) * strides[0]
        if change in (2, 3):
            w = w + rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 inputs both FP16 and FP32 are excluded, presumably because
        # either can be a valid output type - confirm against the spec.
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004197
4198 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a 3D convolution.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC. Spatial dimension
    ``i`` (0 = D, 1 = H, 2 = W) has output size
    ``(ifm.shape[1 + i] - 1 + padding[2i] + padding[2i + 1]
    - (filter.shape[1 + i] - 1) * dilations[i]) // strides[i] + 1``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator providing ``choice``.
        ifm: input tensor, exposing ``shape`` and ``dtype``.
        filter: weight tensor, exposing ``shape``.
        accum_dtype: output dtype in the non-error case.
        strides: one stride per spatial dimension (D, H, W).
        padding: two padding values per spatial dimension (D, H, W).
        dilations: one dilation per spatial dimension (D, H, W).
        error_name: ERROR_IF case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    spatial = [
        (
            ifm.shape[1 + i]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[1 + i] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # change 1/2/3 grows only D/H/W respectively; 4 grows all three.
        # Growth is in whole strides so the non-integer-output error is not hit.
        for i in range(3):
            if change in (i + 1, 4):
                spatial[i] = spatial[i] + rng.choice(choices) * strides[i]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 inputs both FP16 and FP32 are excluded, presumably because
        # either can be a valid output type - confirm against the spec.
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4259
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for a DEPTHWISE_CONV2D test.

    Layouts: IFM is NHWC, the filter is HWCM and the OFM is NHW(C*M).
    Each spatial output extent is
    ``(in - 1 + pad_a + pad_b - (k - 1) * dilation) // stride + 1``, where
    ``padding[0:2]`` applies to the height and ``padding[2:4]`` to the width.

    Args:
        ser: serializer that receives the new tensor via ``addOutput``.
        rng: random generator used to perturb shape/dtype for ERROR_IF tests.
        ifm: input feature map tensor.
        filter: weight tensor.
        accum_dtype: accumulator dtype, used as the output dtype by default.
        strides: ``[stride_h, stride_w]``.
        padding: four padding amounts (height pair, then width pair).
        dilations: ``[dilation_h, dilation_w]``.
        error_name: optional ERROR_IF case to provoke.

    Returns:
        Whatever ``ser.addOutput`` returns for the output tensor.
    """

    def _out_size(axis):
        # Spatial output extent along axis 0 (height) or 1 (width).
        span = ifm.shape[axis + 1] - 1 + padding[2 * axis] + padding[2 * axis + 1]
        span -= (filter.shape[axis] - 1) * dilations[axis]
        return span // strides[axis] + 1

    out_h = _out_size(0)
    out_w = _out_size(1)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow by whole strides so the non-integer-output case is not hit too.
        if change in (1, 3):
            out_h += rng.choice(choices) * strides[0]
        if change in (2, 3):
            out_w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[2] * filter.shape[3]]

    # With a bad input dtype, fall back to a potentially correct output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 (presumably both are valid
        # outputs for FP16); every other input excludes only its own output.
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004310
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor for a 2D pooling test.

    The input is NHWC. The output keeps the input's batch and channel dims;
    only height and width are recomputed from ``kernel``, ``stride`` and
    ``pad`` (``pad[0:2]`` applies to height, ``pad[2:4]`` to width).

    Args:
        ser: serializer that receives the new tensor via ``addOutput``.
        rng: random generator used to perturb shape/dtype for ERROR_IF tests.
        ifm: input tensor.
        kernel: ``[kernel_h, kernel_w]``.
        stride: ``[stride_h, stride_w]``.
        pad: four padding amounts.
        error_name: optional ERROR_IF case to provoke.

    Returns:
        Whatever ``ser.addOutput`` returns for the output tensor.
    """
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # Bad stride or negative padding: the test is invalid anyway, use 1x1.
        out_h = out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow by whole strides so the non-integer-output case is not hit too.
        if change in (1, 3):
            out_h += rng.choice(choices) * stride[0]
        if change in (2, 3):
            out_w += rng.choice(choices) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    all_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # Any listed dtype other than the input's counts as a wrong output dtype.
    wrong_dtypes = list(set(all_dtypes) - {ifm.dtype})
    return ser.addOutput(ofm_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004348
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Create the output tensor for a FULLY_CONNECTED test.

    Shapes: ``input`` is (N, IC), ``filter`` is (OC, IC), output is (N, OC).
    The output dtype is always ``accum_dtype``; it is validated in arg_gen
    (and deliberately invalidated there for ERROR_IF tests). ``rng`` and
    ``error_name`` are accepted for a uniform shaper signature but unused.

    Returns:
        Whatever ``ser.addOutput`` returns for the output tensor.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004361
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the output tensor for a MATMUL test.

    Shapes: ``a`` is (N, H, C), ``b`` is (N, C, W), output is (N, H, W).

    The output dtype is ``accum_dtype`` (validated in arg_gen), except:
      * ``ErrorIf.WrongOutputType`` picks a random dtype from a list of
        incorrect types chosen by ``a.dtype``;
      * ``ErrorIf.WrongInputType`` uses INT32 as a potentially correct
        output dtype.

    Args:
        ser: serializer that receives the new tensor via ``addOutput``.
        rng: random generator used for ERROR_IF dtype selection.
        a: left-hand input tensor.
        b: right-hand input tensor.
        accum_dtype: accumulator dtype used as the normal output dtype.
        error_name: optional ERROR_IF case to provoke.

    Returns:
        Whatever ``ser.addOutput`` returns for the output tensor.

    Raises:
        ValueError: for ``ErrorIf.WrongOutputType`` when ``a.dtype`` has no
            list of incorrect output types (instead of an UnboundLocalError).
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            # INT32 is deliberately absent from this list.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            # INT48 is deliberately absent from this list.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            # Floating-point inputs are only offered integer outputs.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Fail clearly rather than with an unbound local variable.
            raise ValueError(
                f"matmulOp: no incorrect output types defined for input dtype {a.dtype}"
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004409
@staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Create the output tensor for a CONCAT test.

    The output shape starts as a copy of the first input's shape. When the
    sizes along ``axis`` can be summed, each remaining input's ``axis`` size
    is added to it.

    Args:
        ser: serializer that receives the new tensor via ``addOutput``.
        rng: random generator used to perturb shape/dtype for ERROR_IF tests.
        axis: concatenation axis.
        *a: input tensors; the first supplies the base shape and dtype.
        error_name: optional ERROR_IF case to provoke.

    Returns:
        Whatever ``ser.addOutput`` returns for the output tensor.
    """
    first_input = a[0]
    output_shape = first_input.shape.copy()

    # Mismatched ranks or an out-of-range axis make the sum impossible, so
    # those ERROR_IF cases keep the first input's shape unchanged.
    unsummable_errors = (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if error_name not in unsummable_errors:
        for other in a[1:]:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    output_dtype = first_input.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        output_dtype = rng.choice(list(all_dtypes - {first_input.dtype}))

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004445
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor for a PAD test.

    Each output dim is ``padding[dim][0] + padding[dim][1] + a.shape[dim]``.

    Args:
        ser: serializer that receives the new tensor via ``addOutput``.
        rng: random generator used to perturb shape/dtype for ERROR_IF tests.
        a: input tensor.
        padding: per-dimension ``(before, after)`` pairs.
        error_name: optional ERROR_IF case to provoke.

    Returns:
        Whatever ``ser.addOutput`` returns for the output tensor.
    """
    output_shape = a.shape.copy()
    for dim, extent in enumerate(output_shape):
        output_shape[dim] = padding[dim][0] + padding[dim][1] + extent

    if error_name == ErrorIf.PadOutputShapeMismatch:
        # Shrink one randomly chosen dimension by 1 or 2.
        bad_dim = rng.choice(range(len(output_shape)))
        output_shape[bad_dim] -= rng.choice([1, 2])

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    all_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(all_dtypes) - {a.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004474
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor for a RESHAPE test.

    The output shape is a copy of the requested ``shape``.
    ``ErrorIf.TensorSizeInputOutputMismatch`` enlarges every dim by a
    random 1..9, and ``ErrorIf.WrongOutputType`` swaps the input dtype for
    a different one.

    Returns:
        Whatever ``ser.addOutput`` returns for the output tensor.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dim, extent in enumerate(output_shape):
            output_shape[dim] = extent + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    all_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(all_dtypes) - {a.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004499
@staticmethod
def sliceOp(ser, rng, a, start, size, error_name=None):
    """Create the output tensor for a SLICE test.

    The output shape is a copy of ``size``; ``start`` is not used here.
    ``ErrorIf.SizeOutputShapeMismatch`` nudges every dim away from ``size``,
    and ``ErrorIf.WrongOutputType`` swaps the input dtype for a different one.

    Returns:
        Whatever ``ser.addOutput`` returns for the output tensor.
    """
    output_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice(list(set(all_dtypes) - {a.dtype}))

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for dim, extent in enumerate(output_shape):
            # Dims of 2 or less may only grow, so they never drop to zero or below.
            deltas = [1, 2] if extent <= 2 else [-2, -1, 1, 2]
            output_shape[dim] = extent + rng.choice(deltas)

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004531
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor for a TILE test.

    Each output dim is ``a.shape[dim] * multiples[dim]``. Under
    ``ErrorIf.WrongOutputType`` the dtype is replaced by a different one.

    Returns:
        Whatever ``ser.addOutput`` returns for the output tensor.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for dim, extent in enumerate(a.shape):
        output_shape[dim] = extent * multiples[dim]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    all_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(all_dtypes) - {a.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004557
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor for a TRANSPOSE test.

    Output dim ``i`` is ``a.shape[perms[i]]``. For ``ErrorIf.IndexOutsideBounds``
    ``perms`` is not consulted (presumably because it holds an out-of-range
    index) and every output dim copies ``a.shape[0]``.

    Returns:
        Whatever ``ser.addOutput`` returns for the output tensor.
    """
    output_shape = a.shape.copy()

    assert len(perms) == len(output_shape)

    follow_perms = error_name != ErrorIf.IndexOutsideBounds
    for dim in range(len(output_shape)):
        source_dim = perms[dim] if follow_perms else 0
        output_shape[dim] = a.shape[source_dim]

    output_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice(list(set(all_dtypes) - {a.dtype}))

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004587
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor for a GATHER test.

    ``values`` is expected to be rank 3 and ``indices`` rank 2, sharing their
    first dim. The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``. The rank check
    on ``values`` is skipped for ``ErrorIf.WrongRank``.

    Returns:
        Whatever ``ser.addOutput`` returns for the output tensor.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    batch = values.shape[0]
    width = indices.shape[1]
    channels = values.shape[2]
    output_shape = [batch, width, channels]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, values.dtype)

    all_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(all_dtypes) - {values.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Kevin Cheng77d0f762020-11-24 10:26:32 -08004613
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor for a SCATTER test.

    The output has the shape of ``values_in`` and, unless an ERROR_IF case
    overrides it, the same dtype. The asserts check the expected ranks and
    that the N / W / C dims agree between the operands. The rank check on
    ``values_in`` is skipped for ``ErrorIf.WrongRank``.

    Args:
        ser: serializer that receives the new tensor via ``addOutput``.
        rng: random generator used for ERROR_IF dtype selection.
        values_in: tensor being scattered into.
        indices: index tensor.
        input: tensor of values to scatter.
        error_name: optional ERROR_IF case to provoke.

    Returns:
        Whatever ``ser.addOutput`` returns for the output tensor.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    # Copy so the output tensor does not share the values_in tensor's shape
    # object (the other shapers copy their input shapes the same way).
    output_shape = values_in.shape.copy()

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = values_in.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004642
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor for a TABLE test.

    The output has the input's shape. Its dtype is INT32 for an INT16 input
    and INT8 otherwise. The input-dtype assert is skipped for
    ``ErrorIf.WrongInputType``, and ``ErrorIf.WrongOutputType`` replaces the
    output dtype with a different listed one.

    Returns:
        Whatever ``ser.addOutput`` returns for the output tensor.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        output_dtype = DType.INT32
    else:
        output_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice([d for d in candidates if d != output_dtype])

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004662
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor for a RESIZE test.

    ``scale`` holds ``[y_n, y_d, x_n, x_d]``. The output height/width are
    ``((in - 1) * n - offset + border) // d + 1`` per axis, and batch and
    channels are copied from ``input``. ``mode`` and ``input_dtype`` are not
    used here.

    For ERROR_IF tests the sizes are clamped to at least 1 and (except for
    ``ErrorIf.MaxDimExceeded``) below ``MAX_RESIZE_DIMENSION``. Specific
    cases then perturb the shape, batch or channel dims.

    Returns:
        Whatever ``serializer.addOutput`` returns for the output tensor.
    """
    scale_y_n, scale_y_d = scale[0], scale[1]
    scale_x_n, scale_x_d = scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
            max(value, 1) for value in (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        )

    oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
    ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

    if error_name is not None:
        # Scale, offset or border may have been altered for the ERROR_IF,
        # so force a usable output size.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            upper = MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, upper), min(ow, upper)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def _step(size, denom):
            # Move by one denominator so the non-integer error case is avoided.
            if size + denom >= MAX_RESIZE_DIMENSION:
                size -= denom
                assert size > 0  # Should have been caught in agResize
                return size
            return size + denom

        choices = [1, 2, 3]
        change = rng.choice(choices)
        if change in (1, 3):
            oh = _step(oh, scale_y_d)
        if change in (2, 3):
            ow = _step(ow, scale_x_d)

    if error_name == ErrorIf.WrongRank:
        output_dims = [input.shape[0], oh, ow, input.shape[0]]
    elif error_name == ErrorIf.BatchMismatch:
        bad_batch = input.shape[0] + rng.integers(1, 10)
        output_dims = [bad_batch, oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        bad_channels = input.shape[3] + rng.integers(1, 10)
        output_dims = [input.shape[0], oh, ow, bad_channels]
    else:
        output_dims = [input.shape[0], oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004741
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create an output tensor with ``val``'s shape and dtype ``out_dtype``.

    ``rng`` and ``error_name`` are accepted for a uniform shaper signature
    but unused.

    Returns:
        Whatever ``ser.addOutput`` returns for the output tensor.
    """
    shape = val.shape
    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004745
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Create the output tensor for a TRANSPOSE_CONV2D test.

    The output shape comes from the caller rather than being computed here.
    For ``ErrorIf.ConvOutputShapeMismatch`` dims 1 and/or 2 of
    ``output_shape`` are enlarged *in place*, so the caller's object is
    modified too.

    Args:
        ser: serializer that receives the new tensor via ``addOutput``.
        rng: random generator used to perturb shape/dtype for ERROR_IF tests.
        ifm: input feature map tensor.
        output_shape: requested output shape (may be mutated, see above).
        accum_dtype: accumulator dtype, used as the output dtype by default.
        error_name: optional ERROR_IF case to provoke.

    Returns:
        Whatever ``ser.addOutput`` returns for the output tensor.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # (dim index, values of ``change`` that enlarge that dim)
        for dim, triggers in ((1, (1, 3)), (2, (2, 3))):
            if change in triggers:
                output_shape[dim] = output_shape[dim] + rng.choice(choices)

    # With a bad input dtype, fall back to a potentially correct output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32; others exclude only their own.
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00004771
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Create the two output tensors for an RFFT2D test.

    Both outputs have the input's shape with the last dim reduced to
    ``W // 2 + 1``, and the input's dtype. The rank-3 assert is skipped for
    ``ErrorIf.WrongRank``.

    ERROR_IF handling:
      * ``WrongOutputType``: the dtype becomes a usable dtype other than FP32.
      * ``BatchMismatch``: only the batch dim is enlarged by 1..9; the
        ``W // 2 + 1`` last dim is kept, so that just the batch mismatches.
        (Previously the shape was rebuilt from the input shape, which also
        restored the full width and added an unintended width mismatch.)

    Args:
        serializer: serializer that receives the tensors via ``addOutput``.
        rng: random generator used to perturb shape/dtype for ERROR_IF tests.
        value: input tensor.
        error_name: optional ERROR_IF case to provoke.

    Returns:
        A list of the two values returned by ``serializer.addOutput``.
    """
    input_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(input_shape) == 3

    output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1]

    output_dtype = value.dtype
    if error_name == ErrorIf.WrongOutputType:
        excludes = [DType.FP32]
        wrong_dtypes = list(usableDTypes(excludes=excludes))
        output_dtype = rng.choice(wrong_dtypes)
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)

    # Both outputs share the same shape and dtype.
    return [
        serializer.addOutput(output_shape, output_dtype),
        serializer.addOutput(output_shape, output_dtype),
    ]