blob: fddf942a4ef661a38cbb85a19b0327f389fd3b82 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +010016from generator.tosa_utils import DTYPE_ATTRIBUTES
Jeremy Johnson05c711e2022-12-12 18:00:41 +000017from generator.tosa_utils import get_wrong_output_type
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010018from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010019from generator.tosa_utils import usableDTypes
James Ward24dbc422022-10-19 12:20:31 +010020from generator.tosa_utils import vect_f32_to_bf16
Les Bell0e027d42021-11-09 14:42:14 +000021from tosa.DType import DType
22from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010023
24
Eric Kunzee5e26762020-10-13 16:11:07 -070025class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010026 # Maximum rank of tensor supported by test generator.
27 TOSA_TENSOR_MAX_RANK = 6
28
def __init__(self, args):
    """Initialise the test generator from parsed command-line arguments.

    Args:
        args: Parsed arguments; this reads ``output_dir``, ``random_seed``,
            ``tensor_fp_value_range`` and keeps ``args`` for later lookups.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    # Filled in per test by createSerializer().
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set, makeShape returns this shape instead of a random one.
    self.targetted_shape = None
    # Floating point values are drawn between the smallest and largest
    # entries of the configured range.
    fp_range = args.tensor_fp_value_range
    self.random_fp_low = min(fp_range)
    self.random_fp_high = max(fp_range)
Eric Kunzee5e26762020-10-13 16:11:07 -070043
def createSerializer(self, opName, testPath):
    """Start a new test: create its output directory and a fresh serializer.

    The test is placed at ``<basePath>/<opName>/<testPath>``; the directory
    is created if it does not already exist. ``self.testPath`` is set to the
    path relative to ``basePath``.

    Args:
        opName: Operator name, used as the first directory level.
        testPath: Test-specific directory name under ``opName``.
    """
    relative_dir = os.path.join(opName, testPath)
    self.testPath = relative_dir

    output_dir = os.path.join(self.basePath, relative_dir)
    os.makedirs(output_dir, exist_ok=True)
    # saveConstsToFile is driven directly by the dump_consts argument.
    self.ser = ts.TosaSerializer(output_dir, saveConstsToFile=self.args.dump_consts)
Eric Kunzee5e26762020-10-13 16:11:07 -070050
def getSerializer(self):
    """Return the serializer for the current test (``self.ser``)."""
    current = self.ser
    return current
53
def serialize(self, testName):
    """Write the current test's binary and JSON descriptor to its directory.

    Creates ``<testName>.tosa`` (written in binary mode from
    ``self.ser.serialize()``) and ``desc.json`` (from
    ``self.ser.writeJson`` given the ``.tosa`` file name) inside
    ``<basePath>/<testPath>``.

    Args:
        testName: Base name of the ``.tosa`` file.
    """
    test_dir = os.path.join(self.basePath, self.testPath)
    tosa_name = "{}.tosa".format(testName)

    with open(os.path.join(test_dir, tosa_name), "wb") as fd:
        fd.write(self.ser.serialize())

    with open(os.path.join(test_dir, "desc.json"), "w") as fd:
        fd.write(self.ser.writeJson(tosa_name))
Eric Kunzee5e26762020-10-13 16:11:07 -070062
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a freshly seeded numpy Generator.

    Args:
        seed: Seed to use. When ``None`` (the default) the generator is
            seeded with ``self.random_seed + 1``.
    """
    new_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(new_seed)
67
def getRandTensor(self, shape, dtype):
    """Return a random numpy array of ``shape`` holding values for ``dtype``.

    * BOOL: random True/False values.
    * Integer types: uniform over the half-open range listed below; stored
      as int32, except INT48 which is stored as int64.
    * FP16 / FP32: uniform in ``[random_fp_low, random_fp_high)``.
    * BF16: drawn like FP32, passed through ``vect_f32_to_bf16`` and
      returned as float32.

    Raises:
        Exception: if ``dtype`` is not one of the types above.
    """
    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))

    # (dtype, low, high-exclusive, numpy container type)
    integer_specs = (
        # TOSA specific INT4 weight range from -7 to 7
        (DType.INT4, -7, 8, np.int32),
        (DType.INT8, -128, 128, np.int32),
        (DType.UINT8, 0, 256, np.int32),
        (DType.INT16, -32768, 32768, np.int32),
        (DType.UINT16, 0, 65536, np.int32),
        (DType.INT32, -(1 << 31), (1 << 31), np.int32),
        (DType.INT48, -(1 << 47), (1 << 47), np.int64),
    )
    for candidate, low, high, container in integer_specs:
        if dtype == candidate:
            return container(self.rng.integers(low=low, high=high, size=shape))

    if dtype in (DType.FP16, DType.BF16, DType.FP32):
        samples = self.rng.uniform(
            low=self.random_fp_low, high=self.random_fp_high, size=shape
        )
        if dtype == DType.FP16:
            return np.float16(samples)
        as_f32 = np.float32(samples)
        if dtype == DType.BF16:
            # Floor the last 16 bits of each f32 value
            return np.float32(vect_f32_to_bf16(as_f32))
        return as_f32

    raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700112
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one random-valued placeholder per (shape, dtype) pair.

    Args:
        shape_list: Placeholder shapes, in order.
        dtype_list: DTypes paired element-for-element with ``shape_list``;
            must have the same length.

    Returns:
        List of the values returned by ``self.ser.addPlaceholder``, in
        input order.
    """
    assert len(shape_list) == len(dtype_list)

    created = []
    for shape, dtype in zip(shape_list, dtype_list):
        values = self.getRandTensor(shape, dtype)
        created.append(self.ser.addPlaceholder(shape, dtype, values))
    return created
123
def buildConstTensors(self, shape_list, dtype_list):
    """Add one random-valued constant tensor per (shape, dtype) pair.

    Args:
        shape_list: Constant shapes, in order.
        dtype_list: DTypes paired element-for-element with ``shape_list``;
            must have the same length.

    Returns:
        List of the values returned by ``self.ser.addConst``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    created = []
    for shape, dtype in zip(shape_list, dtype_list):
        values = self.getRandTensor(shape, dtype)
        created.append(self.ser.addConst(shape, dtype, values))
    return created
134
def makeShape(self, rank):
    """Return a tensor shape as an ``np.int32`` array.

    If a target shape has been set (and is truthy) it is returned as-is,
    ignoring ``rank``. Otherwise ``rank`` dimensions are drawn uniformly
    from ``[tensor_shape_range[0], tensor_shape_range[1])``.
    """
    if not self.targetted_shape:
        shape_range = self.args.tensor_shape_range
        dims = self.rng.integers(low=shape_range[0], high=shape_range[1], size=rank)
        return np.int32(dims)
    return np.int32(self.targetted_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
def setTargetShape(self, shape):
    """Store ``shape`` as the shape makeShape() should return.

    Passing a falsy value (e.g. ``None``) lets makeShape() go back to
    drawing random shapes.
    """
    self.targetted_shape = shape
148
def randInt(self, low=0, high=256):
    """Return a single ``np.int32`` drawn uniformly from ``[low, high)``."""
    draw = self.rng.integers(low=low, high=high, size=1)
    return np.int32(draw)[0]
151
def getRandNumberDType(self, dtype):
    """Return one random scalar value for ``dtype``.

    * FP32 / FP16: uniform in ``[random_fp_low, random_fp_high)``.
    * BF16: a float32 drawn the same way, passed through ``vect_f32_to_bf16``.
    * BOOL: True or False.
    * Integer types: uniform over the type's range (INT4 limited to -7..7);
      returned as ``np.int32``, except INT48 which is returned as ``np.int64``.
      UINT8 and UINT16 use the same ranges getRandTensor uses for them.

    Raises:
        Exception: if ``dtype`` is not one of the types above.
    """
    if dtype == DType.FP32:
        return np.float32(
            self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
        )
    elif dtype == DType.FP16:
        return np.float16(
            self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
        )
    elif dtype == DType.BF16:
        rand_f32 = np.float32(
            self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
        )
        return vect_f32_to_bf16(rand_f32)
    elif dtype == DType.BOOL:
        return self.rng.choice([False, True])
    # TOSA specific INT4 weight range from -7 to 7
    elif dtype == DType.INT4:
        low, high = (-7, 8)
    elif dtype == DType.INT8:
        low, high = (-128, 128)
    elif dtype == DType.UINT8:
        # Previously unsupported here although getRandTensor handles it.
        low, high = (0, 256)
    elif dtype == DType.INT16:
        low, high = (-32768, 32768)
    elif dtype == DType.UINT16:
        # Previously unsupported here although getRandTensor handles it.
        low, high = (0, 65536)
    elif dtype == DType.INT32:
        low, high = (-(1 << 31), (1 << 31))
    elif dtype == DType.INT48:
        low, high = (-(1 << 47), (1 << 47))
        # Special size: INT48 values do not fit in an int32 container.
        return np.int64(self.rng.integers(low, high, size=1))[0]
    else:
        raise Exception("Unknown dtype: {}".format(dtype))

    return np.int32(self.rng.integers(low, high, size=1))[0]
185
def shapeStr(self, shape):
    """Return ``shape`` as an ``x``-separated string, e.g. ``[1, 2, 3]`` -> ``"1x2x3"``.

    An empty (rank 0) shape gives an empty string.
    """
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700194
def typeStr(self, dtype):
    """Return the short type name used when building test names.

    A single dtype is looked up in ``DTYPE_ATTRIBUTES``. A list or tuple of
    at least two dtypes is rendered as the first two names joined by ``x``.
    Every entry is still converted, so a trailing accumulator type must be
    valid, but it is left out of the string.

    Raises:
        Exception: if a dtype has no entry in ``DTYPE_ATTRIBUTES``.
    """
    if not isinstance(dtype, (list, tuple)):
        if dtype not in DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return DTYPE_ATTRIBUTES[dtype]["str"]

    assert len(dtype) >= 2
    names = [self.typeStr(entry) for entry in dtype]
    # Limit types to the first 2 as the 3rd is the accumulator
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700208
def typeWidth(self, dtype):
    """Get the datatype width for data types.

    Returns:
        The ``"width"`` entry recorded for ``dtype`` in ``DTYPE_ATTRIBUTES``.

    Raises:
        Exception: if ``dtype`` has no entry in ``DTYPE_ATTRIBUTES``.
    """
    if dtype not in DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700215
# Operator build functions
# The argument generators return a list of tuples
# (stringDescriptor, [build_fcn_arg_list]), where the string descriptor
# is used to generate the test name and the build_fcn_arg_list is
# expanded and passed to the operator's build function below.
221
def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
    """Add a single-input operator to the serializer.

    Args:
        op: Either an int operator code, which is added directly without
            validation (per the original note, the placeholder path passes
            an int), or an op-list dict with at least ``"op"`` and
            ``"operands"`` keys.
        a: Input tensor.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf case being generated, or ``None``.
        qinfo: ``[input_zp, output_zp]``; only used for NEGATE, where it
            now defaults to ``[0, 0]`` when not supplied.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # build_placeholder returns an int, ABS/other ops does not
    if isinstance(op, int):
        self.ser.addOperator(op, a.name, result_tens.name, None)
        return result_tens
    elif op["op"] == Op.IDENTITY:
        self.ser.addOperator(op["op"], a.name, result_tens.name, None)
        return result_tens

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongOutputType:
        if result_tens.dtype not in [DType.INT8, DType.UINT8]:
            qinfo = [
                TosaQuantGen.getZeroPoint(self, a.dtype),
                TosaQuantGen.getZeroPoint(self, result_tens.dtype),
            ]

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = None
    if op["op"] == Op.NEGATE:
        # NEGATE needs both zero points; indexing a missing (None) qinfo
        # used to raise TypeError, so fall back to zero zero-points.
        if qinfo is None:
            qinfo = [0, 0]
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
272
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
    """Add a two-input operator whose output comes from binaryBroadcastOp.

    For ERROR_IF tests the input/output name lists may be replaced by
    ``eiInvalidateInputOutputList`` before validation.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    inputs = [a.name, b.name]
    outputs = [out_tensor.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
305
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input operator whose output comes from binaryNonBroadcastOp.

    No ERROR_IF handling happens here: ``validator_fcns`` and ``error_name``
    are accepted but not used.

    Returns:
        The result tensor.
    """
    out_tensor = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op["op"], operand_names, [out_tensor.name])
    return out_tensor
310
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add an ARITHMETIC_RIGHT_SHIFT-style operator with a ``round`` attribute.

    Args:
        round: Value stored via ``ArithmeticRightShiftAttribute``.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    inputs = [a.name, b.name]
    outputs = [out_tensor.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not passed:
        return None

    shift_attr = ts.TosaSerializerAttribute()
    shift_attr.ArithmeticRightShiftAttribute(round)

    self.ser.addOperator(op["op"], inputs, outputs, shift_attr)
    return out_tensor
348
def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
    """Add a MUL operator with the given ``shift`` attribute.

    When the first input is not FP16/BF16/FP32 the result dtype is forced
    to INT32. For the WrongOutputType ERROR_IF case the result dtype is then
    replaced by a random choice of INT8, INT16 or INT48.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Special for multiply:
    # Force the result to INT32 for INT types
    float_types = (DType.FP16, DType.BF16, DType.FP32)
    if a.dtype not in float_types:
        out_tensor.setDtype(DType.INT32)
    if error_name == ErrorIf.WrongOutputType:
        wrong_type = self.rng.choice([DType.INT8, DType.INT16, DType.INT48])
        out_tensor.setDtype(wrong_type)

    # Invalidate Input/Output list for error if checks.
    inputs = [a.name, b.name]
    outputs = [out_tensor.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not passed:
        return None

    mul_attr = ts.TosaSerializerAttribute()
    mul_attr.MulAttribute(shift)

    self.ser.addOperator(op["op"], inputs, outputs, mul_attr)
    return out_tensor
393
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a TABLE operator using ``table`` as its ``TableAttribute``.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    table_attr = ts.TosaSerializerAttribute()
    table_attr.TableAttribute(table)

    # Invalidate Input/Output list for error if checks.
    inputs = [a.name]
    outputs = [out_tensor.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], inputs, outputs, table_attr)

    return out_tensor
427
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Add a SELECT operator choosing between ``a`` and ``b`` by ``cond``.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    inputs = [cond.name, a.name, b.name]
    outputs = [out_tensor.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
464
def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input comparison operator (output from binaryComparisonOp).

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.binaryComparisonOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    inputs = [a.name, b.name]
    outputs = [out_tensor.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
503
def build_argmax(self, op, a, axis, validator_fcns, error_name):
    """Add an ARGMAX-style operator reducing ``a`` along ``axis``.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    inputs = [a.name]
    outputs = [out_tensor.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], inputs, outputs, axis_attr)
    return out_tensor
538
def build_pool2d(
    self,
    op,
    input,
    accum_dtype,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D pooling operator (also used by build_maxpool2d).

    Args:
        accum_dtype: Accumulator dtype stored in the pool attribute.
        stride, pad, kernel: Pooling geometry, passed to the output shaper,
            the ERROR_IF validators and the pool attribute.
        qinfo: ``[input_zp, output_zp]``; ``[0, 0]`` is used when ``None``.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, input, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and input.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, input.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    inputs = [input.name]
    outputs = [out_tensor.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=input.shape,
        input_dtype=input.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not passed:
        return None

    zero_points = [0, 0] if qinfo is None else qinfo

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )

    self.ser.addOperator(op["op"], inputs, outputs, pool_attr)
    return out_tensor
600
def build_maxpool2d(
    self,
    op,
    input,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a max-pool operator by delegating to build_pool2d.

    Max pooling has no accumulator, so ``DType.UNKNOWN`` is passed as
    ``accum_dtype``; all other arguments are forwarded unchanged.

    Returns:
        Whatever build_pool2d returns.
    """
    return self.build_pool2d(
        op,
        input,
        accum_dtype=DType.UNKNOWN,
        stride=stride,
        pad=pad,
        kernel=kernel,
        validator_fcns=validator_fcns,
        error_name=error_name,
        qinfo=qinfo,
    )
625
def build_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a CONV2D-style operator to the serializer.

    Args:
        ifm, filter, bias: Input, weight and bias tensors.
        accum_dtype: Accumulator dtype stored in the conv attribute.
        strides, padding, dilations: Convolution geometry; ``padding`` must
            have four entries.
        qinfo: ``[input_zp, weight_zp]`` zero points; ``[0, 0]`` is used for
            the attribute when ``None`` (previously this raised TypeError).

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    assert len(padding) == 4
    result_tens = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # Same fallback as build_pool2d: the attribute needs concrete zero points.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
697
def build_conv3d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a CONV3D operator and return its output tensor.

    ``qinfo`` is indexed as a pair of zero points for the ConvAttribute.
    Returns None, without adding an operator, when
    TosaErrorValidator.evValidateErrorIfs reports a failed check.
    """
    # Six padding values are expected for the 3D case.
    assert len(padding) == 6
    out_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # A wrong-input-type test on a non-8-bit input gets zero points rebuilt
    # from the actual tensor dtypes.
    # NOTE(review): the second zero point is derived from the output dtype,
    # not filter.dtype -- confirm this is intended.
    if error_name == ErrorIf.WrongInputType:
        if ifm.dtype not in (DType.INT8, DType.UINT8):
            qinfo = [
                TosaQuantGen.getZeroPoint(self, dtype)
                for dtype in (ifm.dtype, out_tensor.dtype)
            ]

    operand_count = sum(op["operands"])
    # The ERROR_IF helper may replace the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=in_names,
        num_operands=operand_count,
        output_list=out_names,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype
    )
    self.ser.addOperator(op["op"], in_names, out_names, conv_attr)
    return out_tensor
769
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a TRANSPOSE_CONV2D operator and return its output tensor.

    ``output_shape`` is passed both to OutputShaper and to the
    TransposeConvAttribute. ``qinfo`` is indexed as a pair of zero points.
    Returns None, without adding an operator, when
    TosaErrorValidator.evValidateErrorIfs reports a failed check.
    """
    # Exactly four out_pad values are expected.
    assert len(out_pad) == 4
    out_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Rebuild zero points from the real dtypes for a wrong-input-type test
    # whose input is not 8-bit.
    # NOTE(review): second zero point uses the output dtype -- confirm.
    wrong_input = error_name == ErrorIf.WrongInputType
    if wrong_input and ifm.dtype not in (DType.INT8, DType.UINT8):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, dtype)
            for dtype in (ifm.dtype, out_tensor.dtype)
        ]

    operand_count = sum(op["operands"])
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=in_names,
        num_operands=operand_count,
        output_list=out_names,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not passed:
        return None

    tconv_attr = ts.TosaSerializerAttribute()
    input_zp = qinfo[0]
    second_zp = qinfo[1]
    tconv_attr.TransposeConvAttribute(
        out_pad, stride, output_shape, input_zp, second_zp, accum_dtype
    )
    self.ser.addOperator(op["op"], in_names, out_names, tconv_attr)
    return out_tensor
834
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a DEPTHWISE_CONV2D operator and return its output tensor.

    ``qinfo`` is indexed as a pair of zero points for the ConvAttribute.
    Returns None, without adding an operator, when
    TosaErrorValidator.evValidateErrorIfs reports a failed check.
    """
    out_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Non-8-bit input under a wrong-input-type test: derive zero points from
    # the actual dtypes.
    # NOTE(review): second zero point uses the output dtype -- confirm.
    if error_name == ErrorIf.WrongInputType:
        if ifm.dtype not in (DType.INT8, DType.UINT8):
            qinfo = [
                TosaQuantGen.getZeroPoint(self, dtype)
                for dtype in (ifm.dtype, out_tensor.dtype)
            ]

    operand_count = sum(op["operands"])
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=in_names,
        num_operands=operand_count,
        output_list=out_names,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not valid:
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype
    )
    self.ser.addOperator(op["op"], in_names, out_names, conv_attr)
    return out_tensor
905
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a FULLY_CONNECTED operator and return its output tensor.

    ``qinfo`` is indexed as a pair of zero points for the attribute.
    Returns None, without adding an operator, when
    TosaErrorValidator.evValidateErrorIfs reports a failed check.
    """
    out_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    # op["operands"] is unpacked as a two-element count pair.
    count_p, count_c = op["operands"]
    operand_count = count_p + count_c
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_count,
        accum_dtype=accum_dtype,
    )
    if not passed:
        return None

    fc_attr = ts.TosaSerializerAttribute()
    fc_attr.FullyConnectedAttribute(qinfo[0], qinfo[1], accum_dtype)
    self.ser.addOperator(op["op"], in_names, out_names, fc_attr)
    return out_tensor
954
def build_matmul(
    self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a MATMUL of ``a`` and ``b`` and return its output tensor.

    ``qinfo`` is indexed as a pair of zero points for the attribute.
    Returns None, without adding an operator, when
    TosaErrorValidator.evValidateErrorIfs reports a failed check.
    """
    out_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    count_p, count_c = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=count_p + count_c,
        accum_dtype=accum_dtype,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    mm_attr = ts.TosaSerializerAttribute()
    mm_attr.MatMulAttribute(qinfo[0], qinfo[1], accum_dtype)
    self.ser.addOperator(op["op"], in_names, out_names, mm_attr)
    return out_tensor
996
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize a reduction operator over ``axis`` and return its output.

    ``validator_fcns`` defaults to None, matching the other build_* methods;
    callers that pass it positionally or by keyword are unaffected.
    Returns None, without adding an operator, when
    TosaErrorValidator.evValidateErrorIfs reports a failed check.
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1031
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a CLAMP operator with random bounds and return its output.

    Two random values of ``a.dtype`` become the bounds. For the
    ErrorIf.MaxSmallerMin test they are forced to differ and are assigned
    inverted (max_val < min_val). Returns None, without adding an operator,
    when TosaErrorValidator.evValidateErrorIfs reports a failed check.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def draw_pair():
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    bounds = draw_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # Equal values cannot produce max < min, so redraw until distinct.
        while bounds[0] == bounds[1]:
            bounds = draw_pair()
        max_val, min_val = min(bounds), max(bounds)
    else:
        max_val, min_val = max(bounds), min(bounds)

    count_p, count_c = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=count_p + count_c,
    )
    if not passed:
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    if a.dtype not in (DType.BF16, DType.FP16, DType.FP32):
        # Non-float dtypes fill the first pair of bound arguments.
        clamp_attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        # Float dtypes fill the second pair of bound arguments.
        clamp_attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)
    return out_tensor
1087
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a LEAKY_RELU operator and return its output tensor.

    The alpha attribute is a random FP32 value. No ERROR_IF validation is
    done here; ``validator_fcns`` is accepted but not used.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    relu_attr = ts.TosaSerializerAttribute()

    alpha = self.getRandNumberDType(DType.FP32)
    relu_attr.LeakyReluAttribute(alpha)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], relu_attr)
    return out_tensor
1096
# NOTE: PRELU still needs an additional type/input that is not generated here.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a PRELU operator on ``a`` and return its output tensor.

    Only the single input ``a`` is wired up, with no attribute and no
    ERROR_IF validation; ``validator_fcns`` is accepted but not used.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1103
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a SIGMOID operator (no attribute) and return its output.

    Returns None, without adding an operator, when
    TosaErrorValidator.evValidateErrorIfs reports a failed check.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    count_p, count_c = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=count_p + count_c,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1134
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a TANH operator (no attribute) and return its output.

    Returns None, without adding an operator, when
    TosaErrorValidator.evValidateErrorIfs reports a failed check.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    count_p, count_c = op["operands"]
    operand_count = count_p + count_c
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_count,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1165
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Serialize a CONCAT operator and return its output tensor.

    The positional arguments are the input tensors followed by the
    concatenation axis. The axis travels at the end so that a variable
    number of tensors can be passed. The int check on the axis is skipped
    for ErrorIf.WrongInputType tests. Returns None, without adding an
    operator, when TosaErrorValidator.evValidateErrorIfs reports a failed
    check.
    """
    if error_name != ErrorIf.WrongInputType:
        assert type(a[-1]) is int

    axis = a[-1]
    tensors = a[:-1]

    out_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, *tensors, error_name=error_name
    )

    count_p, count_c = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [t.name for t in tensors], [out_tensor.name]
    )

    # The first input tensor supplies the reference shape and dtype.
    lead = tensors[0]
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=lead.shape,
        output_shape=out_tensor.shape,
        input_dtype=lead.dtype,
        output_dtype=out_tensor.dtype,
        inputs=tensors,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=count_p + count_c,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001214
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a PAD operator and return its output tensor.

    ``padding`` is flattened via ``.flatten()`` for the attribute, which
    presumes an ndarray-like value (TODO confirm against callers). Returns
    None, without adding an operator, when
    TosaErrorValidator.evValidateErrorIfs reports a failed check.
    """
    out_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    # Built ahead of validation. PadAttribute receives the serializer's
    # builder, so this statement order is kept.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(
        self.ser.builder, padding.flatten(), pad_const_int, pad_const_float
    )

    count_p, count_c = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=count_p + count_c,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001262
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Serialize a RESHAPE of ``a`` to ``newShape`` and return its output.

    Returns None, without adding an operator, when
    TosaErrorValidator.evValidateErrorIfs reports a failed check.
    """
    out_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, a, newShape, error_name
    )

    count_p, count_c = op["operands"]
    operand_count = count_p + count_c
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_count,
    )
    if not passed:
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)
    self.ser.addOperator(op["op"], in_names, out_names, reshape_attr)
    return out_tensor
1298
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize a REVERSE of ``a`` along ``axis`` and return its output.

    The output tensor comes from OutputShaper.unaryOp. Returns None, without
    adding an operator, when TosaErrorValidator.evValidateErrorIfs reports a
    failed check.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    count_p, count_c = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=count_p + count_c,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return out_tensor
1333
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Serialize a TRANSPOSE of ``a`` by ``perms`` and return its output.

    Returns None, without adding an operator, when
    TosaErrorValidator.evValidateErrorIfs reports a failed check.
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    # The attribute is built before validation runs.
    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    count_p, count_c = op["operands"]
    operand_count = count_p + count_c
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_count,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, perm_attr)
    return out_tensor
1368
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Serialize a SLICE of ``a`` given ``start``/``size``; return its output.

    Returns None, without adding an operator, when
    TosaErrorValidator.evValidateErrorIfs reports a failed check.
    """
    out_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, a, start, size, error_name
    )

    count_p, count_c = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=count_p + count_c,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)
    return out_tensor
1406
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Add a TILE operator that repeats ``a`` according to ``multiples``.

    The result tensor comes from ``OutputShaper.tileOp``; the operator is only
    serialized when the ERROR_IF validators accept the generated test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    tiled = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    primary_count, const_count = op["operands"]
    operand_total = primary_count + const_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [tiled.name]
    )

    check_args = dict(
        op=op,
        input_shape=a.shape,
        output_shape=tiled.shape,
        input_dtype=a.dtype,
        output_dtype=tiled.dtype,
        result_tensors=[tiled],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    ):
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)
    self.ser.addOperator(op["op"], inputs, outputs, tile_attr)
    return tiled
1440
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Add a GATHER operator over ``values`` using freshly generated indices.

    A constant INT32 index tensor of shape (values.shape[0], W) is created,
    where W comes from ``self.randInt`` over ``self.args.tensor_shape_range``
    and each index is drawn from [0, values.shape[1]) so it never exceeds the
    K dimension of ``values``.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    k_dim = values.shape[1]
    shape_lo, shape_hi = self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    width = self.randInt(shape_lo, shape_hi)
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values.shape[0], width])
    )
    index_tensor = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    gathered = OutputShaper.gatherOp(
        self.ser, self.rng, values, index_tensor, error_name
    )

    primary_count, const_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, index_tensor.name], [gathered.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=gathered.shape,
        input_dtype=values.dtype,
        output_dtype=gathered.dtype,
        result_tensors=[gathered],
        input_list=inputs,
        output_list=outputs,
        num_operands=primary_count + const_count,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return gathered
1487
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Add a SCATTER operator writing ``input`` into ``values_in``.

    A constant INT32 index tensor of shape (values_in.shape[0], input.shape[1])
    is generated with every index in [0, values_in.shape[1]), so indices stay
    within the K dimension of ``values_in``.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    k_dim = values_in.shape[1]
    width = input.shape[1]
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values_in.shape[0], width])
    )
    index_tensor = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    scattered = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, index_tensor, input, error_name
    )

    primary_count, const_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, index_tensor.name, input.name],
        [scattered.name],
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=scattered.shape,
        input_dtype=values_in.dtype,
        output_dtype=scattered.dtype,
        result_tensors=[scattered],
        input_list=inputs,
        output_list=outputs,
        num_operands=primary_count + const_count,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return scattered
1532
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Add a RESIZE operator for ``input`` with the given mode/scale/offset/border.

    The result tensor comes from ``OutputShaper.resizeOp``; the same
    parameters are handed to the ERROR_IF validators and, if they accept the
    test, serialized in a ResizeAttribute.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    resized = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    primary_count, const_count = op["operands"]
    operand_total = primary_count + const_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [resized.name]
    )

    check_args = dict(
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=resized.shape,
        offset=offset,
        border=border,
        input_list=inputs,
        output_list=outputs,
        result_tensors=[resized],
        num_operands=operand_total,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    ):
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], inputs, outputs, resize_attr)
    return resized
1594
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Add a two-input, two-output operator mapping ``val``/``val2`` to outputs.

    Each output tensor is created by ``OutputShaper.unaryOp`` from the
    corresponding input. ``validator_fcns`` is not used; ``error_name`` is
    only forwarded to the output shaper.

    Returns:
        The first result tensor (the second is created but not returned).
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # Like every other build_* function here, serialize the operator code held
    # in op["op"], not the whole operator definition.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1602
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register the already-built tensor ``val`` as a graph output and return it.

    ``op``, ``validator_fcns`` and ``error_name`` are accepted only so the
    signature matches the other build_* functions; they are not used.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001606
1607 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Add a CAST-style type conversion of ``val`` to ``out_dtype``.

    The result tensor comes from ``OutputShaper.typeConversionOp``; no
    attribute is attached to the operator.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    converted = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    primary_count, const_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [converted.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=converted.shape,
        input_dtype=val.dtype,
        output_dtype=converted.dtype,
        result_tensors=[converted],
        input_list=inputs,
        output_list=outputs,
        num_operands=primary_count + const_count,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return converted
1640
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Add a RESCALE operator converting ``val`` to ``out_dtype``.

    Steps:
      * pick input/output zero points from the dtypes (or force a non-zero
        value for the zero-point ERROR_IF tests),
      * draw a random scale per channel (last dimension) or a single scale,
        clip it, and convert it to multiplier/shift pairs,
      * for valid scale32 tests, clamp the placeholder data saved on disk so
        every value satisfies the apply_scale_32 precondition.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # One multiplier/shift pair per channel, or a single pair for the tensor.
    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Branch order matters: the ERROR_IF checks must precede the UINT16 case.
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Shift as a Python int: shifting a NumPy int32 scalar can stay in
        # int32 (e.g. NumPy 2 / NEP 50 promotion) and overflow once the shift
        # amount reaches the int32 width, giving wrong clamp bounds.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 1)
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1795
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor holding ``cond`` for COND_IF tests.

    Normally a rank-0 BOOL constant. For ErrorIf.CondIfCondNotMatchingBool the
    dtype comes from ``get_wrong_output_type``; for
    ErrorIf.CondIfCondShapeNotSizeOne the shape is randomly [2] or [1, 2].
    """
    cond_type = (
        get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
1812
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Add a COND_IF whose THEN/ELSE blocks each output a random INT32 constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored. For the
    CondIfOutputListThenGraphMismatch / CondIfOutputListElseGraphMismatch
    ERROR_IF tests, the matching block outputs a constant with a perturbed
    shape instead.

    Returns:
        The COND_IF result tensor, or None when ERROR_IF validation rejects
        the test (the operator and blocks are already added at that point).
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    wants_bad_shape = error_name in [
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ]
    if wants_bad_shape:
        # Perturb every dimension; dims <= 3 only grow so they stay positive.
        incorrect_shape = deepcopy(then_tens.shape)
        for dim_idx, dim in enumerate(incorrect_shape):
            if dim > 3:
                delta = self.rng.choice([-3, -2, 2, 3])
            else:
                delta = self.rng.choice([1, 2, 4])
            incorrect_shape[dim_idx] = dim + delta
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The COND_IF result mirrors the shape of either block output.
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block body is a single constant that is also the block output.
    block_specs = (
        (then_block, ErrorIf.CondIfOutputListThenGraphMismatch, then_arr),
        (else_block, ErrorIf.CondIfOutputListElseGraphMismatch, else_arr),
    )
    for block_name, mismatch_error, block_arr in block_specs:
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            block_out = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_out)

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if accepted else None
1881
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Add a COND_IF whose THEN/ELSE blocks apply a binary op to ``a`` and ``b``.

    FP32/BF16/FP16/INT32 inputs use ADD (then) / SUB (else); INT8/INT16 use
    LOGICAL_RIGHT_SHIFT (then) / LOGICAL_LEFT_SHIFT (else). For the
    CondIfInputList*/CondIfOutputList* GraphMismatch ERROR_IF tests the
    selected block declares an input or output with a perturbed shape.

    Returns:
        The COND_IF result tensor, or None when ERROR_IF validation rejects
        the test (the operator and blocks are already added at that point).
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # NOTE(review): unlike build_cond_if_const, small dimensions (<= 3) can
        # become non-positive here - confirm this is intended.
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # The loop variable must not be called ``op``: that would overwrite the
    # operator definition that is passed to the validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
1964
Matthew Haddon630c17c2021-10-14 15:05:41 +01001965 def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001966 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07001967
Kevin Cheng550ccc52021-03-03 11:21:43 -08001968 cond_block = "COND_BLOCK"
1969 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001970
1971 attr = ts.TosaSerializerAttribute()
1972 attr.WhileLoopAttribute(cond_block, body_block)
1973
1974 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001975 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001976 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001977 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07001978
1979 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08001980 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
1981 a_out = self.ser.addIntermediate(a.shape, a.dtype)
Matthew Haddon630c17c2021-10-14 15:05:41 +01001982 if error_name == ErrorIf.InputListOutputListMismatch:
1983 incorrect_acc = deepcopy(acc)
1984 for i in range(len(incorrect_acc.shape)):
1985 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
1986 acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
1987 else:
1988 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001989
1990 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08001991 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001992 op["op"],
Kevin Cheng550ccc52021-03-03 11:21:43 -08001993 [iter.name, a.name, acc.name],
1994 [iter_out.name, a_out.name, acc_out.name],
1995 attr,
1996 )
Kevin Chengb227ae52021-09-02 13:43:17 -07001997 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07001998
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001999 if error_name in [
2000 ErrorIf.InputListCondGraphMismatch,
2001 ErrorIf.InputListBodyGraphInputMismatch,
2002 ErrorIf.InputListBodyGraphOutputMismatch,
2003 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01002004 incorrect_iter = deepcopy(iter)
2005 for i in range(len(incorrect_iter.shape)):
2006 incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
2007 if len(incorrect_iter.shape) == 0:
2008 incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
2009
2010 incorrect_acc = deepcopy(acc)
2011 for i in range(len(incorrect_acc.shape)):
2012 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2013
Eric Kunzee5e26762020-10-13 16:11:07 -07002014 # COND block (input: iter, output: cond_tens )
Jerry Ge9e94af82022-10-27 09:57:00 -07002015 self.ser.addBasicBlock(cond_block)
2016
Matthew Haddon630c17c2021-10-14 15:05:41 +01002017 if error_name == ErrorIf.InputListCondGraphMismatch:
2018 self.ser.addInputTensor(incorrect_iter)
2019 self.ser.addInputTensor(a)
2020 self.ser.addInputTensor(incorrect_acc)
2021 else:
2022 self.ser.addInputTensor(iter)
2023 self.ser.addInputTensor(a)
2024 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002025 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002026
2027 if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002028 cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002029 else:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002030 cond_type = DType.BOOL
2031 if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
2032 choice = self.rng.choice([1, 2])
2033 if choice == 1:
2034 cond_shape = [3]
2035 else:
2036 cond_shape = [1, 2]
2037 else:
2038 cond_shape = []
2039 cond_tens = self.ser.addOutput(cond_shape, cond_type)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002040
Kevin Cheng550ccc52021-03-03 11:21:43 -08002041 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002042
2043 # BODY block (input: a, acc, iter, output: a, acc, iter)
2044 # Note that local intermediate tensors need to be declared here for the outputs
Jerry Ge9e94af82022-10-27 09:57:00 -07002045 self.ser.addBasicBlock(body_block)
2046
Matthew Haddon630c17c2021-10-14 15:05:41 +01002047 if error_name == ErrorIf.InputListBodyGraphInputMismatch:
2048 self.ser.addInputTensor(incorrect_iter)
2049 self.ser.addInputTensor(a)
2050 self.ser.addInputTensor(incorrect_acc)
2051 else:
2052 self.ser.addInputTensor(iter)
2053 self.ser.addInputTensor(a)
2054 self.ser.addInputTensor(acc)
2055
Kevin Cheng550ccc52021-03-03 11:21:43 -08002056 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002057
2058 if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002059 iter_body_out = self.ser.addIntermediate(
2060 incorrect_iter.shape, incorrect_iter.dtype
2061 )
2062 acc_body_out = self.ser.addIntermediate(
2063 incorrect_acc.shape, incorrect_acc.dtype
2064 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01002065 else:
2066 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2067 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
2068
Eric Kunzee5e26762020-10-13 16:11:07 -07002069 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2070 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2071 self.ser.addOutputTensor(iter_body_out)
2072 self.ser.addOutputTensor(a)
2073 self.ser.addOutputTensor(acc_body_out)
2074
Les Bell729b0352021-11-24 10:28:21 +00002075 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01002076 self.ser,
2077 validator_fcns,
2078 error_name,
2079 op=op,
Jerry Ge9e94af82022-10-27 09:57:00 -07002080 basicBlocks=self.ser.currRegion.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00002081 ):
2082 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01002083
Eric Kunzee5e26762020-10-13 16:11:07 -07002084 return acc_out
2085
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
    """Add the operator described by ``op`` with ``val`` as its single input.

    The output tensors are created by ``OutputShaper.rfft2dOp`` (this is the
    RFFT2D builder). The input/output name lists are passed through
    ``TosaErrorIfArgGen.eiInvalidateInputOutputList`` and the configuration is
    checked by ``TosaErrorValidator.evValidateErrorIfs`` before the operator is
    added to the serializer.

    Args:
        op: TOSA_OP_LIST entry; its "operands" and "op" fields are read.
        val: input tensor; its name, shape and dtype are used.
        validator_fcns: ERROR_IF validator functions, or None.
        error_name: the ErrorIf being tested, or None for a positive test.

    Returns:
        The list of output tensors, or None when validation rejects the test
        (no operator is added in that case).
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    n_placeholders, n_consts = op["operands"]
    result_dtypes = [tensor.dtype for tensor in results]

    # NOTE(review): presumably this only alters the name lists when error_name
    # targets a wrong input/output list - confirm in tosa_error_if.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [tensor.name for tensor in results]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)
    return results
2117
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape, rank and dtype filters used to generate tests for an op.

    Requested filters are combined with what the operator entry allows
    (``op["rank"]`` and ``op["types"]``).

    Args:
        op: TOSA_OP_LIST entry; "rank" (a (min, max) pair) and "types" are read.
        shapeFilter: list of requested shapes; a falsy value means [None]
            (no shape restriction).
        rankFilter: list of requested ranks, or None.
        dtypeFilter: list of requested dtypes, or None. Operator dtypes that
            are themselves lists are matched on their first element.
        testType: "positive" or "negative".
        validator: ERROR_IF validator, only used for "negative" tests. It is
            called with check=False and its "param_reqs" override the filters.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or None
        for a negative test without a validator or for any other testType.

    NOTE(review): when neither ranks nor shapes are requested and the op allows
    more than four ranks, range(1, 5) is used as-is even if op["rank"] starts
    above 1 - confirm no operator relies on that combination.
    """
    # Default testing ranks: 1 to 4 inclusive, to keep test sizes reasonably small.
    default_ranks = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    min_rank, max_rank = op["rank"]
    op_ranks = range(min_rank, max_rank + 1)

    if rankFilter is not None:
        # Only keep the requested ranks that the operator allows
        ranks = [rank for rank in rankFilter if min_rank <= rank <= max_rank]
    elif shapeFilter[0] is None:
        # Neither ranks nor shapes requested: use the operator's range unless
        # it is longer than the default range
        ranks = default_ranks if len(op_ranks) > len(default_ranks) else op_ranks
    else:
        ranks = op_ranks

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        # Keep the requested operator dtypes; a dtype list (such as the
        # TYPE_CONV entries) is matched on its first element
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }
    if testType != "negative" or validator is None:
        return None

    # ERROR_IF tests: the validator may pin the rank, dtype and/or shape
    requirements = validator(check=False, op=op)["param_reqs"]
    rank_req = requirements["rank"]
    dtype_req = requirements["dtype"]
    shape_req = requirements["shape"]
    return {
        # Without a shape requirement keep at most two shapes, so the number
        # of tests stays small
        "shapeFilter": shapeFilter[:2] if shape_req is None else shape_req,
        "rankFilter": ranks if rank_req is None else rank_req,
        "dtypeFilter": dtypes if dtype_req is None else dtype_req,
    }
2198
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Build the list of tests to generate for one operator.

    For a "positive" testType a single pass is made. For "negative" one pass
    is made per validator in the op's "error_if_validators" entry, and each
    validator supplies the ERROR_IF name used in the test names.

    Args:
        opName: key into TOSA_OP_LIST.
        shapeFilter: optional list of shapes to restrict tests to. None (the
            default) means no shape restriction, i.e. [None].
        rankFilter: optional list of ranks to restrict tests to.
        dtypeFilter: optional list of dtypes to restrict tests to.
        testType: "positive" or "negative".

    Returns:
        A list of (opName, testStr, dtype, error_name, shapeList, args) tuples.
        The list is empty when create_filter_lists returns None, for example
        for a negative request on an op without ERROR_IF validators.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    if shapeFilter is None:
        # Avoid a shared mutable default; create_filter_lists treats [None]
        # as "no shape restriction"
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _build_fcn, tgen_fcn, _tvgen_fcn, agen_fcn = op["build_fcn"]

    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []

        for r in filterDict["rankFilter"]:
            for t in filterDict["dtypeFilter"]:
                for shape in filterDict["shapeFilter"]:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Each argument entry is (string used in the test name,
                    # argument list for the build function)
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    if testType == "negative":
                        prefix = "{}_ERRORIF_{}".format(opName, error_name)
                    else:
                        prefix = opName

                    for argStr, args in argList:
                        testStr = "{}_{}_{}".format(prefix, shapeStr, typeStr)
                        if argStr:
                            testStr = "{}_{}".format(testStr, argStr)
                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement. A test is removed when ANY validator flags it.
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2305
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build the graph for one generated test and serialize it if it is valid.

    Args:
        opName: key into TOSA_OP_LIST.
        testStr: test name passed to createSerializer.
        dtype_or_dtypeList: a single dtype, repeated for every operand (or for
            every shape with CONCAT), or an explicit per-operand list.
        error_name: the ErrorIf being tested, or None.
        shapeList: operand shapes.
        testArgs: extra positional arguments for the op's build function.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
        AssertionError: if, for non-CONCAT ops, the shape or dtype list length
            does not match the operand count.
        TypeError: re-raised from the build function after printing context.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    is_concat = op["op"] == Op.CONCAT

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        # CONCAT takes one input per shape rather than a fixed operand count
        dtypeList = [dtype_or_dtypeList] * (
            len(shapeList) if is_concat else num_operands
        )

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Build the random tensor operands and the test
    tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

    try:
        call_args = [self, op, *tens, *testArgs]
        call_kwargs = {}
        if error_if_validators is None:
            # Ops without ERROR_IF validators take qinfo positionally
            if qinfo is not None:
                call_args.append(qinfo)
        else:
            call_kwargs["validator_fcns"] = error_if_validators
            call_kwargs["error_name"] = error_name
            if qinfo is not None:
                call_kwargs["qinfo"] = qinfo
        resultName = build_fcn(*call_args, **call_kwargs)
    except TypeError as e:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise e

    if resultName:
        # The test is valid, serialize it
        self.serialize("test")
    else:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002396
def createDynamicOpLists(self):
    """Expand the templated convolution entries of TOSA_OP_LIST.

    The conv2d, depthwise_conv2d and transpose_conv2d "_TEMPLATE" entries are
    shallow-copied once per 2D kernel size, and conv3d once per 3D kernel size.
    Each copy is stored as "<base>_<k0>x<k1>[x<k2>]" with "filter" set to the
    kernel and "template" set to False. Every entry whose "template" field is
    truthy is then removed. Returns early if "conv2d_TEMPLATE" is absent, i.e.
    the expansion has already happened.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when the class is
        # initialized more than once)
        return

    def add_kernel_variant(base, kernel):
        # Shallow copy of the template; "filter" refers to the kernel list itself
        variant = op_list[base + "_TEMPLATE"].copy()
        variant["filter"] = kernel
        variant["template"] = False
        op_list[base + "_" + "x".join(str(dim) for dim in kernel)] = variant

    # Outer loop over kernels, so entries are inserted kernel by kernel
    for kernel in [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]:
        for base in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            add_kernel_variant(base, kernel)

    for kernel in [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]:
        add_kernel_variant("conv3d", kernel)

    # Delete the templates in a second pass rather than while iterating
    templates = [name for name, entry in op_list.items() if entry.get("template")]
    for name in templates:
        del op_list[name]
2447
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.

    Also lints TOSA_OP_LIST. It raises Exception naming the op when:
    - "operands" is missing or does not unpack into 2 items;
    - "build_fcn" is missing or does not unpack into 4 items;
    - the "types" or "op" field is missing.
    A missing "rank" is set to DEFAULT_RANK_RANGE.
    """
    for opName, entry in self.TOSA_OP_LIST.items():
        # Required fields; unpacking also checks each tuple's length
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(opName)
            )

        try:
            _build, _tgen, _tvgen, _agen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    opName
                )
            )

        if "types" not in entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(opName)
            )

        if "op" not in entry:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(opName)
            )

        # Put in default rank range, if missing
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': the Op value of the operator (e.g. Op.ARGMAX)
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max).
#    If not specified, initOpListDefaults sets it to DEFAULT_RANK_RANGE,
#    i.e. (1, TOSA_TENSOR_MAX_RANK).
#    When no rank or shape filter is requested, create_filter_lists uses
#    this range if it spans at most four ranks, otherwise ranks 1-4.
# 'build_fcn': tuple of (build function, TensorGen function,
#    TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested; an entry may itself be a list
#    of per-operand types (see TYPE_CONV)
# 'qgen': optional quantization info generator, called by serializeTest
# 'error_if_validators': optional ERROR_IF validators, used for negative tests
# 'invalid_test_validators': optional checks that remove positive tests
#    expected to fail without matching an ERROR_IF statement
# 'template' / 'filter': used by createDynamicOpLists to expand the
#    *_TEMPLATE entries into one op per kernel size
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]  # floating-types, INT8/16/32 and BOOL
TYPE_FI16 = [DType.FP32, DType.INT16]  # FP32 and INT16

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Default (min, max) rank range applied by initOpListDefaults
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002541
2542 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002543 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002544 "argmax": {
2545 "op": Op.ARGMAX,
2546 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002547 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002548 "build_fcn": (
2549 build_argmax,
2550 TosaTensorGen.tgBasic,
2551 TosaTensorValuesGen.tvgDefault,
2552 TosaArgGen.agAxis,
2553 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002554 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002555 "error_if_validators": (
2556 TosaErrorValidator.evAxisSmallerZero,
2557 TosaErrorValidator.evAxisLargerRank,
2558 TosaErrorValidator.evArgmaxOutputRankMismatch,
2559 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2560 TosaErrorValidator.evWrongRank,
2561 TosaErrorValidator.evWrongInputType,
2562 TosaErrorValidator.evWrongOutputType,
2563 TosaErrorValidator.evWrongInputList,
2564 TosaErrorValidator.evWrongOutputList,
2565 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002566 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002567 "avg_pool2d": {
2568 "op": Op.AVG_POOL2D,
2569 "operands": (1, 0),
2570 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002571 "build_fcn": (
2572 build_pool2d,
2573 TosaTensorGen.tgNHWC,
2574 TosaTensorValuesGen.tvgDefault,
2575 TosaArgGen.agPooling,
2576 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002577 "qgen": TosaQuantGen.qgUnary,
2578 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002579 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002580 "error_if_validators": (
2581 TosaErrorValidator.evKernelSmallerOne,
2582 TosaErrorValidator.evStrideSmallerOne,
2583 TosaErrorValidator.evPadSmallerZero,
2584 TosaErrorValidator.evWrongRank,
2585 TosaErrorValidator.evWrongInputType,
2586 TosaErrorValidator.evWrongOutputType,
2587 TosaErrorValidator.evWrongInputList,
2588 TosaErrorValidator.evWrongOutputList,
2589 TosaErrorValidator.evInputZeroPointNotZero,
2590 TosaErrorValidator.evOutputZeroPointNotZero,
2591 TosaErrorValidator.evPadLargerEqualKernel,
2592 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002593 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002594 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002595 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002596 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002597 "conv2d_TEMPLATE": {
2598 "op": Op.CONV2D,
2599 "operands": (1, 2),
2600 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002601 "build_fcn": (
2602 build_conv2d,
2603 TosaTensorGen.tgConv2D,
2604 TosaTensorValuesGen.tvgDefault,
2605 TosaArgGen.agConv,
2606 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002607 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002608 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002609 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2610 "error_if_validators": (
2611 TosaErrorValidator.evWrongInputType,
2612 TosaErrorValidator.evWrongOutputType,
2613 TosaErrorValidator.evWrongInputList,
2614 TosaErrorValidator.evWrongOutputList,
2615 TosaErrorValidator.evInputZeroPointNotZero,
2616 TosaErrorValidator.evWeightZeroPointNotZero,
2617 TosaErrorValidator.evPadSmallerZero,
2618 TosaErrorValidator.evStrideSmallerOne,
2619 TosaErrorValidator.evDilationSmallerOne,
2620 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002621 TosaErrorValidator.evConvOutputShapeMismatch,
2622 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002623 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002624 "template": True,
2625 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002626 # Templated operator. Filled in by createDynamicOpLists
2627 "conv3d_TEMPLATE": {
2628 "op": Op.CONV3D,
2629 "operands": (1, 2),
2630 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002631 "build_fcn": (
2632 build_conv3d,
2633 TosaTensorGen.tgConv3D,
2634 TosaTensorValuesGen.tvgDefault,
2635 TosaArgGen.agConv,
2636 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002637 "qgen": TosaQuantGen.qgConv,
2638 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002639 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2640 "error_if_validators": (
2641 TosaErrorValidator.evWrongInputType,
2642 TosaErrorValidator.evWrongOutputType,
2643 TosaErrorValidator.evWrongInputList,
2644 TosaErrorValidator.evWrongOutputList,
2645 TosaErrorValidator.evInputZeroPointNotZero,
2646 TosaErrorValidator.evWeightZeroPointNotZero,
2647 TosaErrorValidator.evPadSmallerZero,
2648 TosaErrorValidator.evStrideSmallerOne,
2649 TosaErrorValidator.evDilationSmallerOne,
2650 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002651 TosaErrorValidator.evConvOutputShapeMismatch,
2652 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002653 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002654 "template": True,
2655 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002656 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002657 "depthwise_conv2d_TEMPLATE": {
2658 "op": Op.DEPTHWISE_CONV2D,
2659 "operands": (1, 2),
2660 "filter": [1, 1],
2661 "rank": (4, 4),
2662 "build_fcn": (
2663 build_depthwise_conv2d,
2664 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002665 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002666 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002667 ),
2668 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002669 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002670 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2671 "error_if_validators": (
2672 TosaErrorValidator.evWrongInputType,
2673 TosaErrorValidator.evWrongOutputType,
2674 TosaErrorValidator.evWrongInputList,
2675 TosaErrorValidator.evWrongOutputList,
2676 TosaErrorValidator.evInputZeroPointNotZero,
2677 TosaErrorValidator.evWeightZeroPointNotZero,
2678 TosaErrorValidator.evPadSmallerZero,
2679 TosaErrorValidator.evStrideSmallerOne,
2680 TosaErrorValidator.evDilationSmallerOne,
2681 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002682 TosaErrorValidator.evConvOutputShapeMismatch,
2683 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002684 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002685 "template": True,
2686 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002687 "fully_connected": {
2688 "op": Op.FULLY_CONNECTED,
2689 "operands": (1, 2),
2690 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002691 "build_fcn": (
2692 build_fully_connected,
2693 TosaTensorGen.tgFullyConnected,
2694 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002695 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002696 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002697 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002698 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002699 "error_if_validators": (
2700 TosaErrorValidator.evInputZeroPointNotZero,
2701 TosaErrorValidator.evWeightZeroPointNotZero,
2702 TosaErrorValidator.evWrongRank,
2703 TosaErrorValidator.evWrongInputType,
2704 TosaErrorValidator.evWrongOutputType,
2705 TosaErrorValidator.evWrongInputList,
2706 TosaErrorValidator.evWrongOutputList,
2707 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002708 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002709 "matmul": {
2710 "op": Op.MATMUL,
2711 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002712 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002713 "build_fcn": (
2714 build_matmul,
2715 TosaTensorGen.tgMatmul,
2716 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002717 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002718 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002719 "qgen": TosaQuantGen.qgMatmul,
2720 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002721 "error_if_validators": (
2722 TosaErrorValidator.evInputZeroPointNotZero,
2723 TosaErrorValidator.evWrongRank,
2724 TosaErrorValidator.evWrongInputType,
2725 TosaErrorValidator.evWrongOutputType,
2726 TosaErrorValidator.evWrongInputList,
2727 TosaErrorValidator.evWrongOutputList,
2728 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002729 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002730 "max_pool2d": {
2731 "op": Op.MAX_POOL2D,
2732 "operands": (1, 0),
2733 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002734 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01002735 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002736 TosaTensorGen.tgNHWC,
2737 TosaTensorValuesGen.tvgDefault,
2738 TosaArgGen.agPooling,
2739 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002740 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002741 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002742 "error_if_validators": (
2743 TosaErrorValidator.evKernelSmallerOne,
2744 TosaErrorValidator.evStrideSmallerOne,
2745 TosaErrorValidator.evPadSmallerZero,
2746 TosaErrorValidator.evWrongRank,
2747 TosaErrorValidator.evWrongInputType,
2748 TosaErrorValidator.evWrongOutputType,
2749 TosaErrorValidator.evWrongInputList,
2750 TosaErrorValidator.evWrongOutputList,
2751 TosaErrorValidator.evPadLargerEqualKernel,
2752 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002753 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002754 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002755 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002756 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002757 "transpose_conv2d_TEMPLATE": {
2758 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002759 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002760 "rank": (4, 4),
2761 "build_fcn": (
2762 build_transpose_conv2d,
2763 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002764 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002765 TosaArgGen.agTransposeConv2D,
2766 ),
2767 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002768 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002769 "invalid_test_validators": (
2770 TosaInvalidValidator.ivHeightWidthInvalid,
2771 TosaInvalidValidator.ivNonPositiveOutputShape,
2772 ),
2773 "error_if_validators": (
2774 TosaErrorValidator.evWrongInputType,
2775 TosaErrorValidator.evWrongOutputType,
2776 TosaErrorValidator.evWrongInputList,
2777 TosaErrorValidator.evWrongOutputList,
2778 TosaErrorValidator.evInputZeroPointNotZero,
2779 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07002780 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00002781 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002782 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002783 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002784 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002785 "template": True,
2786 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002787 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002788 "clamp": {
2789 "op": Op.CLAMP,
2790 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002791 "build_fcn": (
2792 build_clamp,
2793 TosaTensorGen.tgBasic,
2794 TosaTensorValuesGen.tvgDefault,
2795 None,
2796 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002797 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002798 "error_if_validators": (
2799 TosaErrorValidator.evMaxSmallerMin,
2800 TosaErrorValidator.evWrongInputType,
2801 TosaErrorValidator.evWrongOutputType,
2802 TosaErrorValidator.evWrongInputList,
2803 TosaErrorValidator.evWrongOutputList,
2804 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002805 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002806 "sigmoid": {
2807 "op": Op.SIGMOID,
2808 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002809 "build_fcn": (
2810 build_sigmoid,
2811 TosaTensorGen.tgBasic,
2812 TosaTensorValuesGen.tvgDefault,
2813 None,
2814 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002815 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002816 "error_if_validators": (
2817 TosaErrorValidator.evWrongInputType,
2818 TosaErrorValidator.evWrongOutputType,
2819 TosaErrorValidator.evWrongInputList,
2820 TosaErrorValidator.evWrongOutputList,
2821 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002822 },
2823 "tanh": {
2824 "op": Op.TANH,
2825 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002826 "build_fcn": (
2827 build_tanh,
2828 TosaTensorGen.tgBasic,
2829 TosaTensorValuesGen.tvgDefault,
2830 None,
2831 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002832 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002833 "error_if_validators": (
2834 TosaErrorValidator.evWrongInputType,
2835 TosaErrorValidator.evWrongOutputType,
2836 TosaErrorValidator.evWrongInputList,
2837 TosaErrorValidator.evWrongOutputList,
2838 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002839 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002840 # Elementwise Binary Operators
2841 "add": {
2842 "op": Op.ADD,
2843 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002844 "build_fcn": (
2845 build_binary_broadcast,
2846 TosaTensorGen.tgBroadcastFuzz,
2847 TosaTensorValuesGen.tvgAddSub,
2848 None,
2849 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002850 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002851 "error_if_validators": (
2852 TosaErrorValidator.evRankMismatch,
2853 TosaErrorValidator.evWrongInputType,
2854 TosaErrorValidator.evWrongOutputType,
2855 TosaErrorValidator.evWrongInputList,
2856 TosaErrorValidator.evWrongOutputList,
2857 TosaErrorValidator.evDimensionMismatch,
2858 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002859 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002860 "arithmetic_right_shift": {
2861 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2862 "operands": (2, 0),
2863 "build_fcn": (
2864 build_arithmetic_right_shift,
2865 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002866 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002867 TosaArgGen.agArithmeticRightShift,
2868 ),
2869 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002870 "error_if_validators": (
2871 TosaErrorValidator.evRankMismatch,
2872 TosaErrorValidator.evWrongInputType,
2873 TosaErrorValidator.evWrongOutputType,
2874 TosaErrorValidator.evWrongInputList,
2875 TosaErrorValidator.evWrongOutputList,
2876 TosaErrorValidator.evDimensionMismatch,
2877 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002878 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002879 "bitwise_and": {
2880 "op": Op.BITWISE_AND,
2881 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002882 "build_fcn": (
2883 build_binary_broadcast,
2884 TosaTensorGen.tgBroadcastFuzz,
2885 TosaTensorValuesGen.tvgDefault,
2886 None,
2887 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002888 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002889 "error_if_validators": (
2890 TosaErrorValidator.evRankMismatch,
2891 TosaErrorValidator.evWrongInputType,
2892 TosaErrorValidator.evWrongOutputType,
2893 TosaErrorValidator.evWrongInputList,
2894 TosaErrorValidator.evWrongOutputList,
2895 TosaErrorValidator.evDimensionMismatch,
2896 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002897 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002898 "bitwise_or": {
2899 "op": Op.BITWISE_OR,
2900 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002901 "build_fcn": (
2902 build_binary_broadcast,
2903 TosaTensorGen.tgBroadcastFuzz,
2904 TosaTensorValuesGen.tvgDefault,
2905 None,
2906 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002907 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002908 "error_if_validators": (
2909 TosaErrorValidator.evRankMismatch,
2910 TosaErrorValidator.evWrongInputType,
2911 TosaErrorValidator.evWrongOutputType,
2912 TosaErrorValidator.evWrongInputList,
2913 TosaErrorValidator.evWrongOutputList,
2914 TosaErrorValidator.evDimensionMismatch,
2915 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002916 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002917 "bitwise_xor": {
2918 "op": Op.BITWISE_XOR,
2919 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002920 "build_fcn": (
2921 build_binary_broadcast,
2922 TosaTensorGen.tgBroadcastFuzz,
2923 TosaTensorValuesGen.tvgDefault,
2924 None,
2925 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002926 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002927 "error_if_validators": (
2928 TosaErrorValidator.evRankMismatch,
2929 TosaErrorValidator.evWrongInputType,
2930 TosaErrorValidator.evWrongOutputType,
2931 TosaErrorValidator.evWrongInputList,
2932 TosaErrorValidator.evWrongOutputList,
2933 TosaErrorValidator.evDimensionMismatch,
2934 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002935 },
Matthew Haddon459443c2021-08-23 16:43:13 +01002936 "intdiv": {
2937 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002938 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002939 "build_fcn": (
2940 build_binary_broadcast,
2941 TosaTensorGen.tgBroadcastFuzz,
2942 TosaTensorValuesGen.tvgIntDiv,
2943 None,
2944 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002945 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002946 "error_if_validators": (
2947 TosaErrorValidator.evRankMismatch,
2948 TosaErrorValidator.evWrongInputType,
2949 TosaErrorValidator.evWrongOutputType,
2950 TosaErrorValidator.evWrongInputList,
2951 TosaErrorValidator.evWrongOutputList,
2952 TosaErrorValidator.evDimensionMismatch,
2953 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002954 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002955 "logical_and": {
2956 "op": Op.LOGICAL_AND,
2957 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002958 "build_fcn": (
2959 build_binary_broadcast,
2960 TosaTensorGen.tgBroadcastFuzz,
2961 TosaTensorValuesGen.tvgDefault,
2962 None,
2963 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002964 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002965 "error_if_validators": (
2966 TosaErrorValidator.evRankMismatch,
2967 TosaErrorValidator.evWrongInputType,
2968 TosaErrorValidator.evWrongOutputType,
2969 TosaErrorValidator.evWrongInputList,
2970 TosaErrorValidator.evWrongOutputList,
2971 TosaErrorValidator.evDimensionMismatch,
2972 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002973 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002974 "logical_left_shift": {
2975 "op": Op.LOGICAL_LEFT_SHIFT,
2976 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002977 "build_fcn": (
2978 build_binary_broadcast,
2979 TosaTensorGen.tgBroadcastFuzz,
2980 TosaTensorValuesGen.tvgLogicalShift,
2981 None,
2982 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002983 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002984 "error_if_validators": (
2985 TosaErrorValidator.evRankMismatch,
2986 TosaErrorValidator.evWrongInputType,
2987 TosaErrorValidator.evWrongOutputType,
2988 TosaErrorValidator.evWrongInputList,
2989 TosaErrorValidator.evWrongOutputList,
2990 TosaErrorValidator.evDimensionMismatch,
2991 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002992 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002993 "logical_right_shift": {
2994 "op": Op.LOGICAL_RIGHT_SHIFT,
2995 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002996 "build_fcn": (
2997 build_binary_broadcast,
2998 TosaTensorGen.tgBroadcastFuzz,
2999 TosaTensorValuesGen.tvgLogicalShift,
3000 None,
3001 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003002 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003003 "error_if_validators": (
3004 TosaErrorValidator.evRankMismatch,
3005 TosaErrorValidator.evWrongInputType,
3006 TosaErrorValidator.evWrongOutputType,
3007 TosaErrorValidator.evWrongInputList,
3008 TosaErrorValidator.evWrongOutputList,
3009 TosaErrorValidator.evDimensionMismatch,
3010 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003011 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003012 "logical_or": {
3013 "op": Op.LOGICAL_OR,
3014 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003015 "build_fcn": (
3016 build_binary_broadcast,
3017 TosaTensorGen.tgBroadcastFuzz,
3018 TosaTensorValuesGen.tvgDefault,
3019 None,
3020 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003021 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003022 "error_if_validators": (
3023 TosaErrorValidator.evRankMismatch,
3024 TosaErrorValidator.evWrongInputType,
3025 TosaErrorValidator.evWrongOutputType,
3026 TosaErrorValidator.evWrongInputList,
3027 TosaErrorValidator.evWrongOutputList,
3028 TosaErrorValidator.evDimensionMismatch,
3029 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003030 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003031 "logical_xor": {
3032 "op": Op.LOGICAL_XOR,
3033 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003034 "build_fcn": (
3035 build_binary_broadcast,
3036 TosaTensorGen.tgBroadcastFuzz,
3037 TosaTensorValuesGen.tvgDefault,
3038 None,
3039 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003040 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003041 "error_if_validators": (
3042 TosaErrorValidator.evRankMismatch,
3043 TosaErrorValidator.evWrongInputType,
3044 TosaErrorValidator.evWrongOutputType,
3045 TosaErrorValidator.evWrongInputList,
3046 TosaErrorValidator.evWrongOutputList,
3047 TosaErrorValidator.evDimensionMismatch,
3048 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003049 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003050 "maximum": {
3051 "op": Op.MAXIMUM,
3052 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003053 "build_fcn": (
3054 build_binary_broadcast,
3055 TosaTensorGen.tgBroadcastFuzz,
3056 TosaTensorValuesGen.tvgDefault,
3057 None,
3058 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003059 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003060 "error_if_validators": (
3061 TosaErrorValidator.evRankMismatch,
3062 TosaErrorValidator.evWrongInputType,
3063 TosaErrorValidator.evWrongOutputType,
3064 TosaErrorValidator.evWrongInputList,
3065 TosaErrorValidator.evWrongOutputList,
3066 TosaErrorValidator.evDimensionMismatch,
3067 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003068 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003069 "minimum": {
3070 "op": Op.MINIMUM,
3071 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003072 "build_fcn": (
3073 build_binary_broadcast,
3074 TosaTensorGen.tgBroadcastFuzz,
3075 TosaTensorValuesGen.tvgDefault,
3076 None,
3077 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003078 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003079 "error_if_validators": (
3080 TosaErrorValidator.evRankMismatch,
3081 TosaErrorValidator.evWrongInputType,
3082 TosaErrorValidator.evWrongOutputType,
3083 TosaErrorValidator.evWrongInputList,
3084 TosaErrorValidator.evWrongOutputList,
3085 TosaErrorValidator.evDimensionMismatch,
3086 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003087 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003088 "mul": {
3089 "op": Op.MUL,
3090 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003091 "build_fcn": (
3092 build_mul,
3093 TosaTensorGen.tgBroadcastFuzz,
3094 TosaTensorValuesGen.tvgMul,
3095 TosaArgGen.agMul,
3096 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003097 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003098 "error_if_validators": (
3099 TosaErrorValidator.evWrongInputType,
3100 TosaErrorValidator.evWrongOutputType,
3101 TosaErrorValidator.evWrongInputList,
3102 TosaErrorValidator.evWrongOutputList,
3103 TosaErrorValidator.evRankMismatch,
3104 TosaErrorValidator.evDimensionMismatch,
3105 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003106 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003107 "pow": {
3108 "op": Op.POW,
3109 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003110 "build_fcn": (
3111 build_binary_broadcast,
3112 TosaTensorGen.tgBroadcastFuzz,
3113 TosaTensorValuesGen.tvgDefault,
3114 None,
3115 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003116 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003117 "error_if_validators": (
3118 TosaErrorValidator.evRankMismatch,
3119 TosaErrorValidator.evWrongInputType,
3120 TosaErrorValidator.evWrongOutputType,
3121 TosaErrorValidator.evWrongInputList,
3122 TosaErrorValidator.evWrongOutputList,
3123 TosaErrorValidator.evDimensionMismatch,
3124 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003125 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003126 "sub": {
3127 "op": Op.SUB,
3128 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003129 "build_fcn": (
3130 build_binary_broadcast,
3131 TosaTensorGen.tgBroadcastFuzz,
3132 TosaTensorValuesGen.tvgAddSub,
3133 None,
3134 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003135 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003136 "error_if_validators": (
3137 TosaErrorValidator.evRankMismatch,
3138 TosaErrorValidator.evWrongInputType,
3139 TosaErrorValidator.evWrongOutputType,
3140 TosaErrorValidator.evWrongInputList,
3141 TosaErrorValidator.evWrongOutputList,
3142 TosaErrorValidator.evDimensionMismatch,
3143 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003144 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003145 "table": {
3146 "op": Op.TABLE,
3147 # Use the automatic generation functions to create the input array
3148 # but create the table tensor in the build function, as it may be
3149 # a different type from the input
3150 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003151 "build_fcn": (
3152 build_table,
3153 TosaTensorGen.tgBasic,
3154 TosaTensorValuesGen.tvgDefault,
3155 TosaArgGen.agTable,
3156 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003157 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003158 "error_if_validators": (
3159 TosaErrorValidator.evWrongInputType,
3160 TosaErrorValidator.evWrongOutputType,
3161 TosaErrorValidator.evWrongInputList,
3162 TosaErrorValidator.evWrongOutputList,
3163 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003164 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003165 # Elementwise Unary operators
3166 "abs": {
3167 "op": Op.ABS,
3168 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003169 "build_fcn": (
3170 build_unary,
3171 TosaTensorGen.tgBasic,
3172 TosaTensorValuesGen.tvgDefault,
3173 None,
3174 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003175 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003176 "error_if_validators": (
3177 TosaErrorValidator.evWrongInputType,
3178 TosaErrorValidator.evWrongOutputType,
3179 TosaErrorValidator.evWrongInputList,
3180 TosaErrorValidator.evWrongOutputList,
3181 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003182 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003183 "bitwise_not": {
3184 "op": Op.BITWISE_NOT,
3185 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003186 "build_fcn": (
3187 build_unary,
3188 TosaTensorGen.tgBasic,
3189 TosaTensorValuesGen.tvgDefault,
3190 None,
3191 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003192 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003193 "error_if_validators": (
3194 TosaErrorValidator.evWrongInputType,
3195 TosaErrorValidator.evWrongOutputType,
3196 TosaErrorValidator.evWrongInputList,
3197 TosaErrorValidator.evWrongOutputList,
3198 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003199 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003200 "ceil": {
3201 "op": Op.CEIL,
3202 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003203 "build_fcn": (
3204 build_unary,
3205 TosaTensorGen.tgBasic,
3206 TosaTensorValuesGen.tvgDefault,
3207 None,
3208 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003209 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003210 "error_if_validators": (
3211 TosaErrorValidator.evWrongInputType,
3212 TosaErrorValidator.evWrongOutputType,
3213 TosaErrorValidator.evWrongInputList,
3214 TosaErrorValidator.evWrongOutputList,
3215 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003216 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003217 "clz": {
3218 "op": Op.CLZ,
3219 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003220 "build_fcn": (
3221 build_unary,
3222 TosaTensorGen.tgBasic,
3223 TosaTensorValuesGen.tvgDefault,
3224 None,
3225 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003226 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003227 "error_if_validators": (
3228 TosaErrorValidator.evWrongInputType,
3229 TosaErrorValidator.evWrongOutputType,
3230 TosaErrorValidator.evWrongInputList,
3231 TosaErrorValidator.evWrongOutputList,
3232 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003233 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003234 "exp": {
3235 "op": Op.EXP,
3236 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003237 "build_fcn": (
3238 build_unary,
3239 TosaTensorGen.tgBasic,
3240 TosaTensorValuesGen.tvgDefault,
3241 None,
3242 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003243 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003244 "error_if_validators": (
3245 TosaErrorValidator.evWrongInputType,
3246 TosaErrorValidator.evWrongOutputType,
3247 TosaErrorValidator.evWrongInputList,
3248 TosaErrorValidator.evWrongOutputList,
3249 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003250 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003251 "floor": {
3252 "op": Op.FLOOR,
3253 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003254 "build_fcn": (
3255 build_unary,
3256 TosaTensorGen.tgBasic,
3257 TosaTensorValuesGen.tvgDefault,
3258 None,
3259 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003260 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003261 "error_if_validators": (
3262 TosaErrorValidator.evWrongInputType,
3263 TosaErrorValidator.evWrongOutputType,
3264 TosaErrorValidator.evWrongInputList,
3265 TosaErrorValidator.evWrongOutputList,
3266 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003267 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003268 "log": {
3269 "op": Op.LOG,
3270 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003271 "build_fcn": (
3272 build_unary,
3273 TosaTensorGen.tgBasic,
3274 TosaTensorValuesGen.tvgDefault,
3275 None,
3276 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003277 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003278 "error_if_validators": (
3279 TosaErrorValidator.evWrongInputType,
3280 TosaErrorValidator.evWrongOutputType,
3281 TosaErrorValidator.evWrongInputList,
3282 TosaErrorValidator.evWrongOutputList,
3283 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003284 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003285 "logical_not": {
3286 "op": Op.LOGICAL_NOT,
3287 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003288 "build_fcn": (
3289 build_unary,
3290 TosaTensorGen.tgBasic,
3291 TosaTensorValuesGen.tvgDefault,
3292 None,
3293 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003294 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003295 "error_if_validators": (
3296 TosaErrorValidator.evWrongInputType,
3297 TosaErrorValidator.evWrongOutputType,
3298 TosaErrorValidator.evWrongInputList,
3299 TosaErrorValidator.evWrongOutputList,
3300 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003301 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003302 "negate": {
3303 "op": Op.NEGATE,
3304 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003305 "build_fcn": (
3306 build_unary,
3307 TosaTensorGen.tgBasic,
3308 TosaTensorValuesGen.tvgNegate,
3309 None,
3310 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003311 "qgen": TosaQuantGen.qgUnary,
3312 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003313 "error_if_validators": (
3314 TosaErrorValidator.evInputZeroPointNotZero,
3315 TosaErrorValidator.evOutputZeroPointNotZero,
3316 TosaErrorValidator.evWrongInputType,
3317 TosaErrorValidator.evWrongOutputType,
3318 TosaErrorValidator.evWrongInputList,
3319 TosaErrorValidator.evWrongOutputList,
3320 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003321 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003322 "reciprocal": {
3323 "op": Op.RECIPROCAL,
3324 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003325 "build_fcn": (
3326 build_unary,
3327 TosaTensorGen.tgBasic,
3328 TosaTensorValuesGen.tvgDefault,
3329 None,
3330 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003331 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003332 "error_if_validators": (
3333 TosaErrorValidator.evWrongInputType,
3334 TosaErrorValidator.evWrongOutputType,
3335 TosaErrorValidator.evWrongInputList,
3336 TosaErrorValidator.evWrongOutputList,
3337 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003338 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003339 "rsqrt": {
3340 "op": Op.RSQRT,
3341 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003342 "build_fcn": (
3343 build_unary,
3344 TosaTensorGen.tgBasic,
3345 TosaTensorValuesGen.tvgDefault,
3346 None,
3347 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003348 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003349 "error_if_validators": (
3350 TosaErrorValidator.evWrongInputType,
3351 TosaErrorValidator.evWrongOutputType,
3352 TosaErrorValidator.evWrongInputList,
3353 TosaErrorValidator.evWrongOutputList,
3354 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003355 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003356 # Elementwise Ternary operators
3357 "select": {
3358 "op": Op.SELECT,
3359 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003360 "build_fcn": (
3361 build_select,
3362 TosaTensorGen.tgBroadcastFuzz,
3363 TosaTensorValuesGen.tvgSelect,
3364 None,
3365 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003366 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003367 "error_if_validators": (
3368 TosaErrorValidator.evRankMismatch,
3369 TosaErrorValidator.evWrongInputType,
3370 TosaErrorValidator.evWrongOutputType,
3371 TosaErrorValidator.evWrongInputList,
3372 TosaErrorValidator.evWrongOutputList,
3373 TosaErrorValidator.evDimensionMismatch,
3374 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003375 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003376 # Comparison operators
3377 "equal": {
3378 "op": Op.EQUAL,
3379 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003380 "build_fcn": (
3381 build_comparison,
3382 TosaTensorGen.tgBroadcastFuzz,
3383 TosaTensorValuesGen.tvgEqual,
3384 None,
3385 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003386 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003387 "error_if_validators": (
3388 TosaErrorValidator.evRankMismatch,
3389 TosaErrorValidator.evWrongInputType,
3390 TosaErrorValidator.evWrongOutputType,
3391 TosaErrorValidator.evWrongInputList,
3392 TosaErrorValidator.evWrongOutputList,
3393 TosaErrorValidator.evDimensionMismatch,
3394 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003395 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003396 "greater_equal": {
3397 "op": Op.GREATER_EQUAL,
3398 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003399 "build_fcn": (
3400 build_comparison,
3401 TosaTensorGen.tgBroadcastFuzz,
3402 TosaTensorValuesGen.tvgDefault,
3403 None,
3404 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003405 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003406 "error_if_validators": (
3407 TosaErrorValidator.evRankMismatch,
3408 TosaErrorValidator.evWrongInputType,
3409 TosaErrorValidator.evWrongOutputType,
3410 TosaErrorValidator.evWrongInputList,
3411 TosaErrorValidator.evWrongOutputList,
3412 TosaErrorValidator.evDimensionMismatch,
3413 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003414 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003415 "greater": {
3416 "op": Op.GREATER,
3417 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003418 "build_fcn": (
3419 build_comparison,
3420 TosaTensorGen.tgBroadcastFuzz,
3421 TosaTensorValuesGen.tvgDefault,
3422 None,
3423 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003424 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003425 "error_if_validators": (
3426 TosaErrorValidator.evRankMismatch,
3427 TosaErrorValidator.evWrongInputType,
3428 TosaErrorValidator.evWrongOutputType,
3429 TosaErrorValidator.evWrongInputList,
3430 TosaErrorValidator.evWrongOutputList,
3431 TosaErrorValidator.evDimensionMismatch,
3432 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003433 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003434 # Reduction operators
3435 "reduce_all": {
3436 "op": Op.REDUCE_ALL,
3437 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003438 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003439 "build_fcn": (
3440 build_reduce,
3441 TosaTensorGen.tgBasic,
3442 TosaTensorValuesGen.tvgDefault,
3443 TosaArgGen.agAxis,
3444 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003445 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003446 "error_if_validators": (
3447 TosaErrorValidator.evAxisLargerRank,
3448 TosaErrorValidator.evAxisSmallerZero,
3449 TosaErrorValidator.evShapeOfAxisNotOne,
3450 TosaErrorValidator.evWrongInputType,
3451 TosaErrorValidator.evWrongOutputType,
3452 TosaErrorValidator.evWrongRank,
3453 TosaErrorValidator.evWrongInputList,
3454 TosaErrorValidator.evWrongOutputList,
3455 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003456 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003457 "reduce_any": {
3458 "op": Op.REDUCE_ANY,
3459 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003460 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003461 "build_fcn": (
3462 build_reduce,
3463 TosaTensorGen.tgBasic,
3464 TosaTensorValuesGen.tvgDefault,
3465 TosaArgGen.agAxis,
3466 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003467 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003468 "error_if_validators": (
3469 TosaErrorValidator.evAxisLargerRank,
3470 TosaErrorValidator.evAxisSmallerZero,
3471 TosaErrorValidator.evShapeOfAxisNotOne,
3472 TosaErrorValidator.evWrongInputType,
3473 TosaErrorValidator.evWrongOutputType,
3474 TosaErrorValidator.evWrongRank,
3475 TosaErrorValidator.evWrongInputList,
3476 TosaErrorValidator.evWrongOutputList,
3477 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003478 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003479 "reduce_max": {
3480 "op": Op.REDUCE_MAX,
3481 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003482 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003483 "build_fcn": (
3484 build_reduce,
3485 TosaTensorGen.tgBasic,
3486 TosaTensorValuesGen.tvgDefault,
3487 TosaArgGen.agAxis,
3488 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003489 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003490 "error_if_validators": (
3491 TosaErrorValidator.evAxisLargerRank,
3492 TosaErrorValidator.evAxisSmallerZero,
3493 TosaErrorValidator.evShapeOfAxisNotOne,
3494 TosaErrorValidator.evWrongInputType,
3495 TosaErrorValidator.evWrongOutputType,
3496 TosaErrorValidator.evWrongRank,
3497 TosaErrorValidator.evWrongInputList,
3498 TosaErrorValidator.evWrongOutputList,
3499 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003500 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003501 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003502 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003503 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003504 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003505 "build_fcn": (
3506 build_reduce,
3507 TosaTensorGen.tgBasic,
3508 TosaTensorValuesGen.tvgDefault,
3509 TosaArgGen.agAxis,
3510 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003511 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003512 "error_if_validators": (
3513 TosaErrorValidator.evAxisLargerRank,
3514 TosaErrorValidator.evAxisSmallerZero,
3515 TosaErrorValidator.evShapeOfAxisNotOne,
3516 TosaErrorValidator.evWrongInputType,
3517 TosaErrorValidator.evWrongOutputType,
3518 TosaErrorValidator.evWrongRank,
3519 TosaErrorValidator.evWrongInputList,
3520 TosaErrorValidator.evWrongOutputList,
3521 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003522 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003523 "reduce_product": {
3524 "op": Op.REDUCE_PRODUCT,
3525 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003526 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003527 "build_fcn": (
3528 build_reduce,
3529 TosaTensorGen.tgBasic,
3530 TosaTensorValuesGen.tvgDefault,
3531 TosaArgGen.agAxis,
3532 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003533 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003534 "error_if_validators": (
3535 TosaErrorValidator.evAxisLargerRank,
3536 TosaErrorValidator.evAxisSmallerZero,
3537 TosaErrorValidator.evShapeOfAxisNotOne,
3538 TosaErrorValidator.evWrongInputType,
3539 TosaErrorValidator.evWrongOutputType,
3540 TosaErrorValidator.evWrongRank,
3541 TosaErrorValidator.evWrongInputList,
3542 TosaErrorValidator.evWrongOutputList,
3543 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003544 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003545 "reduce_sum": {
3546 "op": Op.REDUCE_SUM,
3547 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003548 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003549 "build_fcn": (
3550 build_reduce,
3551 TosaTensorGen.tgBasic,
3552 TosaTensorValuesGen.tvgReduceSum,
3553 TosaArgGen.agAxis,
3554 ),
James Ward24dbc422022-10-19 12:20:31 +01003555 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003556 "error_if_validators": (
3557 TosaErrorValidator.evAxisLargerRank,
3558 TosaErrorValidator.evAxisSmallerZero,
3559 TosaErrorValidator.evShapeOfAxisNotOne,
3560 TosaErrorValidator.evWrongInputType,
3561 TosaErrorValidator.evWrongOutputType,
3562 TosaErrorValidator.evWrongRank,
3563 TosaErrorValidator.evWrongInputList,
3564 TosaErrorValidator.evWrongOutputList,
3565 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003566 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003567 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003568 "concat": {
3569 "op": Op.CONCAT,
3570 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003571 "build_fcn": (
3572 build_concat,
3573 TosaTensorGen.tgConcat,
3574 TosaTensorValuesGen.tvgConcat,
3575 TosaArgGen.agAxis,
3576 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003577 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003578 "error_if_validators": (
3579 TosaErrorValidator.evAxisLargerRank,
3580 TosaErrorValidator.evAxisSmallerZero,
3581 TosaErrorValidator.evConcatInputRankMismatch,
3582 TosaErrorValidator.evConcatShapeSumMismatch,
3583 TosaErrorValidator.evConcatInputDimMismatch,
3584 TosaErrorValidator.evWrongInputType,
3585 TosaErrorValidator.evWrongOutputType,
3586 TosaErrorValidator.evWrongOutputList,
3587 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003588 },
3589 "pad": {
3590 "op": Op.PAD,
3591 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003592 "rank": (1, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003593 "build_fcn": (
3594 build_pad,
3595 TosaTensorGen.tgBasic,
3596 TosaTensorValuesGen.tvgDefault,
3597 TosaArgGen.agPad,
3598 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003599 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003600 "error_if_validators": (
3601 TosaErrorValidator.evWrongInputType,
3602 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003603 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003604 TosaErrorValidator.evWrongOutputType,
3605 TosaErrorValidator.evWrongInputList,
3606 TosaErrorValidator.evWrongOutputList,
3607 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003608 },
3609 "reshape": {
3610 "op": Op.RESHAPE,
3611 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003612 "build_fcn": (
3613 build_reshape,
3614 TosaTensorGen.tgBasic,
3615 TosaTensorValuesGen.tvgDefault,
3616 TosaArgGen.agReshape,
3617 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003618 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003619 "error_if_validators": (
3620 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3621 TosaErrorValidator.evWrongInputType,
3622 TosaErrorValidator.evWrongOutputType,
3623 TosaErrorValidator.evWrongInputList,
3624 TosaErrorValidator.evWrongOutputList,
3625 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003626 },
3627 "reverse": {
3628 "op": Op.REVERSE,
3629 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003630 "build_fcn": (
3631 build_reverse,
3632 TosaTensorGen.tgBasic,
3633 TosaTensorValuesGen.tvgDefault,
3634 TosaArgGen.agAxis,
3635 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003636 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003637 "error_if_validators": (
3638 TosaErrorValidator.evAxisSmallerZero,
3639 TosaErrorValidator.evAxisLargerRank,
3640 TosaErrorValidator.evWrongInputType,
3641 TosaErrorValidator.evWrongOutputType,
3642 TosaErrorValidator.evWrongInputList,
3643 TosaErrorValidator.evWrongOutputList,
3644 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003645 },
3646 "slice": {
3647 "op": Op.SLICE,
3648 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003649 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003650 "build_fcn": (
3651 build_slice,
3652 TosaTensorGen.tgBasic,
3653 TosaTensorValuesGen.tvgDefault,
3654 TosaArgGen.agSlice,
3655 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003656 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003657 "error_if_validators": (
3658 TosaErrorValidator.evStartSmallerZero,
3659 TosaErrorValidator.evSizeSmallerEqualZero,
3660 TosaErrorValidator.evStartSizeOutsideBounds,
3661 TosaErrorValidator.evSizeOutputShapeMismatch,
3662 TosaErrorValidator.evInputSizeStartLengthMismatch,
3663 TosaErrorValidator.evWrongRank,
3664 TosaErrorValidator.evWrongInputType,
3665 TosaErrorValidator.evWrongOutputType,
3666 TosaErrorValidator.evWrongInputList,
3667 TosaErrorValidator.evWrongOutputList,
3668 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003669 },
3670 "tile": {
3671 "op": Op.TILE,
3672 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003673 "build_fcn": (
3674 build_tile,
3675 TosaTensorGen.tgBasic,
3676 TosaTensorValuesGen.tvgDefault,
3677 TosaArgGen.agTile,
3678 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003679 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003680 "error_if_validators": (
3681 TosaErrorValidator.evWrongInputType,
3682 TosaErrorValidator.evWrongOutputType,
3683 TosaErrorValidator.evWrongInputList,
3684 TosaErrorValidator.evWrongOutputList,
3685 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003686 },
3687 "transpose": {
3688 "op": Op.TRANSPOSE,
3689 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003690 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003691 "build_fcn": (
3692 build_transpose,
3693 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003694 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003695 TosaArgGen.agTranspose,
3696 ),
3697 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003698 "error_if_validators": (
3699 TosaErrorValidator.evIndexOutsideBounds,
3700 TosaErrorValidator.evIndexUsedTwice,
3701 TosaErrorValidator.evWrongInputType,
3702 TosaErrorValidator.evWrongOutputType,
3703 TosaErrorValidator.evWrongInputList,
3704 TosaErrorValidator.evWrongOutputList,
3705 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003706 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003707 # Data nodes
3708 "const": {
3709 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003710 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003711 "build_fcn": (
3712 build_const,
3713 TosaTensorGen.tgBasic,
3714 TosaTensorValuesGen.tvgDefault,
3715 None,
3716 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003717 "types": TYPE_FIB,
3718 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003719 "identity": {
3720 "op": Op.IDENTITY,
3721 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003722 "build_fcn": (
3723 build_unary,
3724 TosaTensorGen.tgBasic,
3725 TosaTensorValuesGen.tvgDefault,
3726 None,
3727 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003728 "types": TYPE_FIB,
3729 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003730 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003731 "gather": {
3732 "op": Op.GATHER,
3733 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3734 "operands": (1, 0),
3735 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003736 "build_fcn": (
3737 build_gather,
3738 TosaTensorGen.tgBasic,
3739 TosaTensorValuesGen.tvgDefault,
3740 None,
3741 ),
James Ward24dbc422022-10-19 12:20:31 +01003742 "types": (
3743 DType.INT8,
3744 DType.INT16,
3745 DType.INT32,
3746 DType.FP16,
3747 DType.BF16,
3748 DType.FP32,
3749 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003750 "error_if_validators": (
3751 TosaErrorValidator.evWrongInputType,
3752 TosaErrorValidator.evWrongOutputType,
3753 TosaErrorValidator.evWrongInputList,
3754 TosaErrorValidator.evWrongOutputList,
3755 TosaErrorValidator.evWrongRank,
3756 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003757 },
3758 "scatter": {
3759 "op": Op.SCATTER,
3760 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003761 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003762 "operands": (2, 0),
3763 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003764 "build_fcn": (
3765 build_scatter,
3766 TosaTensorGen.tgScatter,
3767 TosaTensorValuesGen.tvgDefault,
3768 None,
3769 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003770 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003771 "error_if_validators": (
3772 TosaErrorValidator.evWrongInputType,
3773 TosaErrorValidator.evWrongOutputType,
3774 TosaErrorValidator.evWrongInputList,
3775 TosaErrorValidator.evWrongOutputList,
3776 TosaErrorValidator.evWrongRank,
3777 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003778 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003779 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003780 "resize": {
3781 "op": Op.RESIZE,
3782 "operands": (1, 0),
3783 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003784 "build_fcn": (
3785 build_resize,
3786 TosaTensorGen.tgNHWC,
3787 TosaTensorValuesGen.tvgDefault,
3788 TosaArgGen.agResize,
3789 ),
James Ward24dbc422022-10-19 12:20:31 +01003790 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003791 "invalid_test_validators": (
3792 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003793 ),
3794 "error_if_validators": (
3795 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003796 TosaErrorValidator.evScaleSmallerEqualZero,
3797 TosaErrorValidator.evScaleNLargerMax,
3798 TosaErrorValidator.evScaleDLargerMax,
3799 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003800 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003801 TosaErrorValidator.evBorderSmallerMin,
3802 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003803 TosaErrorValidator.evWrongInputType,
3804 TosaErrorValidator.evWrongOutputType,
3805 TosaErrorValidator.evWrongRank,
3806 TosaErrorValidator.evWrongInputList,
3807 TosaErrorValidator.evWrongOutputList,
3808 TosaErrorValidator.evBatchMismatch,
3809 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003810 TosaErrorValidator.evResizeOutputShapeMismatch,
3811 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003812 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003813 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003814 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003815 "cast": {
3816 "op": Op.CAST,
3817 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003818 "build_fcn": (
3819 build_cast,
3820 TosaTensorGen.tgBasic,
3821 TosaTensorValuesGen.tvgDefault,
3822 TosaArgGen.agCast,
3823 ),
James Ward8b390432022-08-12 20:48:56 +01003824 "types": (
3825 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01003826 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003827 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01003828 DType.INT8,
3829 DType.INT16,
3830 DType.INT32,
3831 DType.BOOL,
3832 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003833 "error_if_validators": (
3834 TosaErrorValidator.evWrongInputType,
3835 TosaErrorValidator.evWrongOutputType,
3836 TosaErrorValidator.evWrongInputList,
3837 TosaErrorValidator.evWrongOutputList,
3838 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003839 },
3840 "rescale": {
3841 "op": Op.RESCALE,
3842 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003843 "build_fcn": (
3844 build_rescale,
3845 TosaTensorGen.tgBasic,
3846 TosaTensorValuesGen.tvgDefault,
3847 TosaArgGen.agRescale,
3848 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003849 "types": [
3850 DType.UINT8,
3851 DType.INT8,
3852 DType.INT16,
3853 DType.INT32,
3854 DType.INT48,
3855 DType.UINT16,
3856 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003857 "error_if_validators": (
3858 TosaErrorValidator.evInputZeroPointNotZero,
3859 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003860 TosaErrorValidator.evU16InputZeroPointNotValid,
3861 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003862 TosaErrorValidator.evScaleTrue,
3863 TosaErrorValidator.evScaleNotTrue,
3864 TosaErrorValidator.evWrongInputType,
3865 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003866 TosaErrorValidator.evWrongInputList,
3867 TosaErrorValidator.evWrongOutputList,
3868 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003869 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003870 # Custom
3871 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003872 # Control flow operators
        # Two variants of cond_if: one that generates one of two constant tensors
        # (no inputs to the basic blocks, one output), and another that either
        # adds or subtracts two tensors (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003876 "cond_if_const": {
3877 "op": Op.COND_IF,
3878 "operands": (0, 2),
3879 "build_fcn": (
3880 build_cond_if_const,
3881 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003882 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003883 TosaArgGen.agCondIf,
3884 ),
3885 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003886 "error_if_validators": (
3887 TosaErrorValidator.evOutputListThenGraphMismatch,
3888 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003889 TosaErrorValidator.evCondIfCondNotMatchingBool,
3890 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003891 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003892 },
3893 "cond_if_binary": {
3894 "op": Op.COND_IF,
3895 "operands": (2, 0),
3896 "build_fcn": (
3897 build_cond_if_binary,
3898 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003899 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003900 TosaArgGen.agCondIf,
3901 ),
Les Bell6040b4d2021-10-11 12:50:31 +01003902 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003903 "error_if_validators": (
3904 TosaErrorValidator.evInputListThenGraphMismatch,
3905 TosaErrorValidator.evInputListElseGraphMismatch,
3906 TosaErrorValidator.evOutputListThenGraphMismatch,
3907 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003908 TosaErrorValidator.evCondIfCondNotMatchingBool,
3909 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003910 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003911 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003912 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003913 "while_loop": {
3914 "op": Op.WHILE_LOOP,
3915 "operands": (0, 1),
3916 "build_fcn": (
3917 build_while_loop,
3918 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003919 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003920 TosaArgGen.agWhileLoop,
3921 ),
3922 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003923 "error_if_validators": (
3924 TosaErrorValidator.evInputListOutputListMismatch,
3925 TosaErrorValidator.evInputListCondGraphMismatch,
3926 TosaErrorValidator.evInputListBodyGraphInputMismatch,
3927 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
3928 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003929 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003930 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003931 },
Luke Hutton261b7b62023-01-10 14:50:31 +00003932 "rfft2d": {
3933 "op": Op.RFFT2D,
3934 "operands": (1, 0),
3935 "rank": (3, 3),
3936 "build_fcn": (
3937 build_rfft2d,
3938 TosaTensorGen.tgRFFT2d,
3939 TosaTensorValuesGen.tvgDefault,
3940 TosaArgGen.agNone,
3941 ),
3942 "types": [DType.FP32],
3943 "error_if_validators": (
3944 TosaErrorValidator.evWrongInputType,
3945 TosaErrorValidator.evWrongOutputType,
3946 TosaErrorValidator.evWrongInputList,
3947 TosaErrorValidator.evWrongOutputList,
3948 TosaErrorValidator.evWrongRank,
3949 TosaErrorValidator.evBatchMismatch,
3950 TosaErrorValidator.evKernelNotPowerOfTwo,
3951 ),
3952 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003953 }
3954
Kevin Cheng550ccc52021-03-03 11:21:43 -08003955
Eric Kunzee5e26762020-10-13 16:11:07 -07003956class OutputShaper:
3957 # Methods in this class compute the expected output shape and datatype
3958 # for common classes of operations
3959 def __init__(self):
3960 pass
3961
    # Each of these methods computes the output shape and dtype for its class
    # of operations and returns the result of ser.addOutput() (the new output tensor)
3964 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003965 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
3966 if error_name != ErrorIf.RankMismatch:
3967 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003968 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003969
3970 shape = []
3971 for i in range(len(a.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003972 if a.shape[i] == 1 and error_name is None:
Eric Kunzee5e26762020-10-13 16:11:07 -07003973 shape.append(b.shape[i])
3974 else:
3975 shape.append(a.shape[i])
3976
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003977 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003978 all_dtypes = [
3979 DType.INT8,
3980 DType.INT16,
3981 DType.INT32,
3982 DType.INT48,
James Ward24dbc422022-10-19 12:20:31 +01003983 DType.FP16,
3984 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003985 DType.FP32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003986 ]
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003987 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
3988 outputDType = rng.choice(wrong_dtypes)
3989 else:
3990 outputDType = a.dtype
3991
3992 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07003993
3994 @staticmethod
3995 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003996 assert len(a.shape) == len(b.shape)
3997 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003998
3999 shape = []
4000 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004001 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07004002 shape.append(a.shape[i])
4003
Kevin Cheng550ccc52021-03-03 11:21:43 -08004004 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004005
4006 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004007 def unaryOp(ser, rng, a, error_name=None):
4008 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004009 all_dtypes = [
4010 DType.INT8,
4011 DType.INT16,
4012 DType.INT32,
4013 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004014 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004015 DType.FP16,
4016 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004017 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004018 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4019 outputDType = rng.choice(wrong_dtypes)
4020 else:
4021 outputDType = a.dtype
4022
4023 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004024
4025 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004026 def selectOp(ser, rng, cond, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004027 if error_name != ErrorIf.RankMismatch:
4028 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004029 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004030
4031 shape = []
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004032 for i in range(len(cond.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004033 if cond.shape[i] == 1 and error_name is None:
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004034 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
4035 else:
4036 shape.append(cond.shape[i])
Eric Kunzee5e26762020-10-13 16:11:07 -07004037
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004038 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004039 all_dtypes = [
4040 DType.INT8,
4041 DType.INT16,
4042 DType.INT32,
4043 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004044 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004045 DType.FP16,
4046 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004047 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004048 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4049 outputDType = rng.choice(wrong_dtypes)
4050 else:
4051 outputDType = a.dtype
4052
4053 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004054
4055 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004056 def binaryComparisonOp(ser, rng, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004057 if error_name != ErrorIf.RankMismatch:
4058 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004059 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004060
4061 # Do broadcast
4062 shape = []
4063 for i in range(len(a.shape)):
Eric Kunzea1d49852022-01-04 10:07:29 -08004064 if a.shape[i] == 1 and len(b.shape) > i:
Eric Kunzee5e26762020-10-13 16:11:07 -07004065 shape.append(b.shape[i])
4066 else:
4067 shape.append(a.shape[i])
4068
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004069 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004070 wrong_dtypes = [
4071 DType.INT8,
4072 DType.INT16,
4073 DType.INT32,
4074 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004075 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004076 DType.FP16,
4077 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004078 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004079 outputDType = rng.choice(wrong_dtypes)
4080 else:
4081 outputDType = DType.BOOL
4082
4083 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004084
4085 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01004086 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004087 shape = a.shape.copy()
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004088 if error_name not in [
4089 ErrorIf.AxisSmallerZero,
4090 ErrorIf.AxisLargerRank,
4091 ErrorIf.ShapeOfAxisNotOne,
4092 ]:
Matthew Haddond6ce7252021-09-29 15:35:44 +01004093 shape[axis] = 1
4094 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
4095 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07004096
Matthew Haddond6ce7252021-09-29 15:35:44 +01004097 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004098 all_dtypes = [
4099 DType.INT8,
4100 DType.INT16,
4101 DType.INT32,
4102 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004103 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004104 DType.FP16,
4105 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004106 ]
Matthew Haddond6ce7252021-09-29 15:35:44 +01004107 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4108 outputDType = rng.choice(wrong_dtypes)
4109 else:
4110 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004111
Matthew Haddond6ce7252021-09-29 15:35:44 +01004112 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004113
4114 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004115 def argmaxOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004116 shape = a.shape.copy()
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004117
4118 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
4119 del shape[axis]
4120
4121 if error_name == ErrorIf.ArgmaxOutputRankMismatch:
4122 remove = rng.choice([True, False])
4123 if remove and len(shape) > 1:
4124 del shape[0]
4125 else:
4126 shape.append(1)
4127 elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
4128 for i in range(len(shape)):
4129 shape[i] = shape[i] + rng.integers(1, 10)
4130
4131 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004132 all_dtypes = [
4133 DType.INT8,
4134 DType.INT16,
4135 DType.INT32,
4136 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004137 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004138 DType.FP16,
4139 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004140 ]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004141 wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
4142 outputDType = rng.choice(wrong_dtypes)
4143 else:
4144 outputDType = DType.INT32
4145
4146 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004147
4148 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004149 def conv2dOp(
4150 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
4151 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07004152
4153 # IFM: NHWC
4154 # Filter: OHWI
4155 # OFM: NHWC
4156
Kevin Cheng550ccc52021-03-03 11:21:43 -08004157 h = (
4158 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004159 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004160 + padding[0]
4161 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004162 - (filter.shape[1] - 1) * dilations[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004163 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004164
Kevin Cheng550ccc52021-03-03 11:21:43 -08004165 w = (
4166 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004167 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004168 + padding[2]
4169 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004170 - (filter.shape[2] - 1) * dilations[1]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004171 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004172
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004173 if error_name == ErrorIf.ConvOutputShapeMismatch:
4174 choices = [1, 2, 3]
4175 change = rng.choice(choices)
4176 # increment in multiples of stride to not hit non-integer error case
4177 if change in [1, 3]:
4178 h = h + (rng.choice(choices) * strides[0])
4179 if change in [2, 3]:
4180 w = w + (rng.choice(choices) * strides[1])
Les Bell0e027d42021-11-09 14:42:14 +00004181
Eric Kunzee5e26762020-10-13 16:11:07 -07004182 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
4183
James Ward8b390432022-08-12 20:48:56 +01004184 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004185 # Pick some potentially correct output dtype if input type is incorrect
4186 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004187 else:
James Ward8b390432022-08-12 20:48:56 +01004188 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004189
4190 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004191 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004192 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004193 else:
4194 excludes = [out_dtype]
4195 wrong_dtypes = list(usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004196 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07004197
Kevin Cheng550ccc52021-03-03 11:21:43 -08004198 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004199
4200 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for a CONV3D operator.

    Layouts: IFM is NDHWC, filter is ODHWI and OFM is NDHWC.

    Each spatial output extent is
    ``(in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1``
    where ``padding`` holds (before, after) pairs for D, H and W in turn.
    ``error_name`` selects an ERROR_IF case whose shape or dtype is then
    deliberately corrupted.

    Returns:
        Whatever ``ser.addOutput(ofm_shape, out_dtype)`` returns.
    """
    spatial = []
    for axis in range(3):
        # ifm / filter spatial dims start at index 1 (after N / O).
        extent = (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis + 1] - 1) * dilations[axis]
        ) // strides[axis] + 1
        spatial.append(extent)
    d, h, w = spatial

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # Grow the selected axes by whole multiples of their stride so the
        # non-integer-output error case is not triggered instead.
        if change in (1, 4):
            d += rng.choice(choices) * strides[0]
        if change in (2, 4):
            h += rng.choice(choices) * strides[1]
        if change in (3, 4):
            w += rng.choice(choices) * strides[2]

    ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

    # A deliberately wrong input type has no meaningful accumulator type, so
    # use some potentially correct output dtype instead.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 inputs both FP16 and FP32 are excluded (treated as valid
        # outputs here); otherwise only the expected output type is excluded.
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4261
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for a DEPTHWISE_CONV2D operator.

    Layouts: IFM is NHWC, filter is HWCM and OFM is NHW(C*M).

    Spatial extents follow
    ``(in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1``
    with ``padding`` ordered (H before, H after, W before, W after).
    ``error_name`` selects an ERROR_IF case whose shape or dtype is then
    deliberately corrupted.

    Returns:
        Whatever ``ser.addOutput(ofm_shape, out_dtype)`` returns.
    """
    # The filter is HWCM, so its kernel extents sit at indices 0 and 1,
    # whereas the NHWC input's spatial dims sit at indices 1 and 2.
    h, w = (
        (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis] - 1) * dilations[axis]
        )
        // strides[axis]
        + 1
        for axis in (0, 1)
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # change 1 -> H, 2 -> W, 3 -> both; step in whole strides to avoid
        # the non-integer-output error case.
        if change in (1, 3):
            h += rng.choice(choices) * strides[0]
        if change in (2, 3):
            w += rng.choice(choices) * strides[1]

    out_channels = filter.shape[2] * filter.shape[3]
    ofm_shape = [ifm.shape[0], h, w, out_channels]

    if error_name == ErrorIf.WrongInputType:
        # No meaningful accumulator for a wrong input type; pick a plausible one.
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        wrong_dtypes = list(usableDTypes(excludes=excludes))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004312
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor for a 2D pooling operator.

    The input is NHWC; the output keeps N and C and has spatial size
    ``(in + pad_before + pad_after - kernel) // stride + 1``, with ``pad``
    ordered (H before, H after, W before, W after).
    ``error_name`` selects an ERROR_IF case that corrupts shape or dtype.
    """
    bad_params = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if bad_params:
        # The real size cannot be computed from an invalid stride/pad and the
        # test is invalid anyway, so fall back to a 1x1 output.
        h, w = 1, 1
    else:
        h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Step in whole strides so the non-integer error case is not hit.
        if change in (1, 3):
            h += rng.choice(choices) * stride[0]
        if change in (2, 3):
            w += rng.choice(choices) * stride[1]

    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    outputDType = ifm.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(list(set(all_dtypes) - {ifm.dtype}))

    return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004350
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Create the output tensor for a FULLY_CONNECTED operator.

    Shapes: input is (N, IC), filter is (OC, IC) and the output is (N, OC).
    The output dtype is always ``accum_dtype``; it is validated (or
    deliberately invalidated for ERROR_IF tests) during argument
    generation, so ``rng`` and ``error_name`` are not used here.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004363
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the output tensor for a MATMUL operator.

    Shapes: a is (N, H, C), b is (N, C, W) and the output is (N, H, W).
    The output dtype is ``accum_dtype`` (validated in arg_gen) unless an
    ERROR_IF case asks for a wrong input or output type.
    """
    batch, rows, cols = a.shape[0], a.shape[1], b.shape[2]
    output_shape = [batch, rows, cols]

    if error_name == ErrorIf.WrongOutputType:
        # Each candidate tuple leaves out the output type(s) that are valid
        # for the given input type.
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004411
@staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Create the output tensor for a CONCAT operator.

    The output starts as a copy of the first input's shape, and the sizes
    of the remaining inputs along ``axis`` are added to it. The dtype
    follows the first input. ERROR_IF cases may leave the shape unsummed,
    inflate the summed size, or pick a wrong dtype.
    """
    first = a[0]
    output_shape = first.shape.copy()

    # Summing is impossible when ranks differ or the axis is out of range,
    # so in those cases just keep the first input's shape.
    can_sum = error_name != ErrorIf.ConcatInputRankMismatch and error_name not in [
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    ]
    if can_sum:
        for other in a[1:]:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, first.dtype)

    all_dtypes = {
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    }
    wrong_dtypes = list(all_dtypes - {first.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004447
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor for a PAD operator.

    Each output dim is ``padding[i][0] + padding[i][1] + a.shape[i]``.

    ERROR_IF handling:
      * ``PadOutputShapeMismatch``: one randomly chosen dim is changed by 1 or
        2. It is shrunk when the result stays positive; otherwise it is grown,
        so a small dim (1 or 2) never becomes zero or negative.
      * ``WrongOutputType``: the dtype is replaced by a different one.
    """
    output_shape = a.shape.copy()

    for i in range(len(output_shape)):
        output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(output_shape)))
        delta = rng.choice([1, 2])
        # Any non-zero change creates the mismatch; avoid producing an
        # invalid (non-positive) dimension when the chosen dim is small.
        if output_shape[bad_dim] - delta >= 1:
            output_shape[bad_dim] -= delta
        else:
            output_shape[bad_dim] += delta

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004476
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor for a RESHAPE operator.

    The output takes the requested ``shape`` and the input's dtype, except
    for the ERROR_IF cases that corrupt the element count or the dtype.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        # Grow every dim by a random 1..9 so the element counts no longer match.
        for idx, dim in enumerate(output_shape):
            output_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    all_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(all_dtypes) - {a.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004501
@staticmethod
def sliceOp(ser, rng, a, start, size, error_name=None):
    """Create the output tensor for a SLICE operator.

    The output shape is ``size`` and the dtype follows the input. ``start``
    is not needed to shape the output. ERROR_IF cases perturb every output
    dim or pick a wrong dtype.
    """
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(list(set(all_dtypes) - {a.dtype}))
    else:
        outputDType = a.dtype

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(output_shape):
            # Small dims are only ever grown so they stay positive.
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = dim + rng.choice(deltas)

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004533
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor for a TILE operator.

    Each output dim is the input dim times the matching entry of
    ``multiples`` (which must have the same length as the input shape).
    The dtype follows the input unless ERROR_IF asks for a wrong one.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for idx, (dim, mult) in enumerate(zip(a.shape, multiples)):
        output_shape[idx] = dim * mult

    outputDType = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(list(set(all_dtypes) - {a.dtype}))

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004559
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor for a TRANSPOSE operator.

    Output dim ``i`` is ``a.shape[perms[i]]``. For the
    ``IndexOutsideBounds`` ERROR_IF case every output dim is ``a.shape[0]``
    instead, and ``perms`` is not used for indexing.
    """
    output_shape = a.shape.copy()
    assert len(perms) == len(output_shape)

    if error_name == ErrorIf.IndexOutsideBounds:
        # perms is presumably out of range in this case, so don't index with it.
        source_dims = [0] * len(output_shape)
    else:
        source_dims = perms

    for idx in range(len(output_shape)):
        output_shape[idx] = a.shape[source_dims[idx]]

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(list(set(all_dtypes) - {a.dtype}))
    else:
        outputDType = a.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004589
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor for a GATHER operator.

    ``values`` is rank 3 (unless testing ``WrongRank``) and ``indices`` is
    rank 2 with a matching first dim. The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]`` and the dtype
    follows ``values`` unless ERROR_IF asks for a wrong one.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, values.dtype)

    all_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(all_dtypes) - {values.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Kevin Cheng77d0f762020-11-24 10:26:32 -08004615
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor for a SCATTER operator.

    The output has exactly the shape of ``values_in`` (the same shape object
    is passed on), and its dtype follows ``values_in`` unless ERROR_IF asks
    for a wrong one. The asserts check the N / W / C consistency between
    ``values_in``, ``indices`` and ``input``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(values_in.shape, values_in.dtype)

    all_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(all_dtypes) - {values_in.dtype})
    return ser.addOutput(values_in.shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004644
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor for a TABLE operator.

    The output has the input's shape. Its dtype is INT32 for INT16 inputs
    and INT8 otherwise, unless ERROR_IF asks for a wrong output type.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        output_dtype = DType.INT32
    else:
        output_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = [dt for dt in all_dtypes if dt != output_dtype]
        output_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004664
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor for a RESIZE operator.

    ``scale`` is (y_n, y_d, x_n, x_d); output height/width are
    ``((in - 1) * n - offset + border) // d + 1``. ``mode`` and
    ``input_dtype`` do not affect the output here.

    For ERROR_IF tests the computed size is clamped to a valid range (and
    to below MAX_RESIZE_DIMENSION unless testing ``MaxDimExceeded``), then
    the shape may be corrupted according to ``error_name``.
    """
    scale_y_n, scale_y_d = scale[0], scale[1]
    scale_x_n, scale_x_d = scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Keep every scale term >= 1 so the size computation stays sane.
        scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
            max(term, 1) for term in (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        )

    oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
    ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

    if error_name is not None:
        # scale, offset or border may have been altered for the ERROR_IF,
        # so make sure the output tensor itself is still valid.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, limit), min(ow, limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Step by the scale denominator so the non-integer error case is not
        # hit; step down instead of up if going up would exceed the limit.
        if change in (1, 3):
            if oh + scale_y_d < MAX_RESIZE_DIMENSION:
                oh += scale_y_d
            else:
                oh -= scale_y_d
                assert oh > 0  # Should have been caught in agResize
        if change in (2, 3):
            if ow + scale_x_d < MAX_RESIZE_DIMENSION:
                ow += scale_x_d
            else:
                ow -= scale_x_d
                assert ow > 0  # Should have been caught in agResize

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # Dim 3 of the input is not read here (presumably it may not exist
        # when the rank is wrong); the batch size is reused for the last dim.
        output_dims = [batch, oh, ow, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004743
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create the output tensor for a type conversion operator.

    The output keeps ``val``'s shape and uses ``out_dtype``; ``rng`` and
    ``error_name`` are not used.
    """
    shape = val.shape
    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004747
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Create the output tensor for a TRANSPOSE_CONV2D operator.

    The output shape is supplied by the caller rather than computed here.
    For ``ConvOutputShapeMismatch`` its H and/or W entries (indices 1 and 2)
    are increased by 1-3 *in place*, so the caller's ``output_shape`` list
    is modified as well.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # change 1 -> H, 2 -> W, 3 -> both.
        for dim, triggers in ((1, (1, 3)), (2, (2, 3))):
            if change in triggers:
                output_shape[dim] = output_shape[dim] + rng.choice(choices)

    if error_name == ErrorIf.WrongInputType:
        # No meaningful accumulator for a wrong input type; pick a plausible one.
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 inputs both FP16 and FP32 are excluded from the candidates.
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00004773
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Create the two output tensors of an RFFT2D operator.

    The input is expected to be rank 3 (unless testing ``WrongRank``). Both
    outputs (presumably the real and imaginary parts) share the shape
    ``[*input_shape[:-1], input_shape[-1] // 2 + 1]`` and the input's dtype.

    ERROR_IF handling:
      * ``WrongOutputType``: a dtype other than FP32 is chosen.
      * ``BatchMismatch``: only the batch (first) dim is increased by 1-9. The
        reduced last dim is kept, so batch is the only mismatch.

    Returns:
        A list of the two results of ``serializer.addOutput``.
    """
    input_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(input_shape) == 3

    # The last dim is reduced to W // 2 + 1; the leading dims are kept.
    output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1]

    output_dtype = value.dtype
    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = list(usableDTypes(excludes=[DType.FP32]))
        output_dtype = rng.choice(wrong_dtypes)
    elif error_name == ErrorIf.BatchMismatch:
        # Change only the batch dim; rebuilding from input_shape would also
        # lose the W // 2 + 1 reduction and introduce a second error.
        output_shape[0] = input_shape[0] + rng.integers(1, 10)

    outputs = []
    outputs.append(serializer.addOutput(output_shape, output_dtype))
    outputs.append(serializer.addOutput(output_shape, output_dtype))
    return outputs