blob: 7fef9422c446cfb1c13c8894dd9ffab3d203da99 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +010016from generator.tosa_utils import DTYPE_ATTRIBUTES
Luke Huttona4e48ca2023-02-22 11:53:48 +000017from generator.tosa_utils import get_rank_mismatch_shape
Jeremy Johnson05c711e2022-12-12 18:00:41 +000018from generator.tosa_utils import get_wrong_output_type
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010019from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010020from generator.tosa_utils import usableDTypes
James Ward24dbc422022-10-19 12:20:31 +010021from generator.tosa_utils import vect_f32_to_bf16
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
25
Eric Kunzee5e26762020-10-13 16:11:07 -070026class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010027 # Maximum rank of tensor supported by test generator.
28 TOSA_TENSOR_MAX_RANK = 6
29
Eric Kunzee5e26762020-10-13 16:11:07 -070030 def __init__(self, args):
31 self.args = args
32 self.basePath = args.output_dir
33 self.random_seed = args.random_seed
34 self.ser = None
35 self.rng = np.random.default_rng(self.random_seed)
36 self.createDynamicOpLists()
37 self.initOpListDefaults()
38 self.quantGen = TosaQuantGen()
39 # Force makeShape to do a specific starting shape
40 self.targetted_shape = None
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010041 # Work out floating point range
42 self.random_fp_low = min(args.tensor_fp_value_range)
43 self.random_fp_high = max(args.tensor_fp_value_range)
Eric Kunzee5e26762020-10-13 16:11:07 -070044
45 def createSerializer(self, opName, testPath):
46 self.testPath = os.path.join(opName, testPath)
47
48 fullPath = os.path.join(self.basePath, self.testPath)
49 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnsona0848c62022-09-15 15:01:30 +010050 self.ser = ts.TosaSerializer(fullPath, saveConstsToFile=self.args.dump_consts)
Eric Kunzee5e26762020-10-13 16:11:07 -070051
52 def getSerializer(self):
53 return self.ser
54
55 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -080056 with open(
57 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
58 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -070059 fd.write(self.ser.serialize())
60
Kevin Cheng550ccc52021-03-03 11:21:43 -080061 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
62 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -070063
Matthew Haddon74567092021-07-16 15:38:20 +010064 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000065 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +010066 seed = self.random_seed + 1
67 self.rng = np.random.default_rng(seed)
68
Eric Kunzee5e26762020-10-13 16:11:07 -070069 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -070070 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -070071 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -070072 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -070073 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -070074 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070075 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +010076 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
77 elif dtype == DType.UINT8:
78 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070079 elif dtype == DType.INT16:
80 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010081 elif dtype == DType.UINT16:
82 return np.int32(self.rng.integers(low=0, high=65536, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070083 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -080084 return np.int32(
85 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
86 )
Eric Kunzee5e26762020-10-13 16:11:07 -070087 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -080088 return np.int64(
89 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
90 )
James Ward8b390432022-08-12 20:48:56 +010091 elif dtype == DType.FP16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010092 return np.float16(
93 self.rng.uniform(
94 low=self.random_fp_low, high=self.random_fp_high, size=shape
95 )
96 )
James Ward24dbc422022-10-19 12:20:31 +010097 elif dtype == DType.BF16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010098 f32_tensor = np.float32(
99 self.rng.uniform(
100 low=self.random_fp_low, high=self.random_fp_high, size=shape
101 )
102 )
James Ward24dbc422022-10-19 12:20:31 +0100103 # Floor the last 16 bits of each f32 value
104 return np.float32(vect_f32_to_bf16(f32_tensor))
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100105 elif dtype == DType.FP32:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100106 return np.float32(
107 self.rng.uniform(
108 low=self.random_fp_low, high=self.random_fp_high, size=shape
109 )
110 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700111 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800112 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700113
Kevin Cheng989cb052021-04-28 16:29:44 -0700114 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700115 placeholders = []
116
Kevin Cheng989cb052021-04-28 16:29:44 -0700117 assert len(shape_list) == len(dtype_list)
118
119 for idx, shape in enumerate(shape_list):
120 arr = self.getRandTensor(shape, dtype_list[idx])
121 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700122
123 return placeholders
124
Kevin Cheng989cb052021-04-28 16:29:44 -0700125 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700126 consts = []
127
Kevin Cheng989cb052021-04-28 16:29:44 -0700128 assert len(shape_list) == len(dtype_list)
129
130 for idx, shape in enumerate(shape_list):
131 arr = self.getRandTensor(shape, dtype_list[idx])
132 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700133
134 return consts
135
136 def makeShape(self, rank):
137 if self.targetted_shape:
138 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800139 return np.int32(
140 self.rng.integers(
141 low=self.args.tensor_shape_range[0],
142 high=self.args.tensor_shape_range[1],
143 size=rank,
144 )
145 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700146
147 def setTargetShape(self, shape):
148 self.targetted_shape = shape
149
150 def randInt(self, low=0, high=256):
151 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
152
153 def getRandNumberDType(self, dtype):
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100154 if dtype == DType.FP32:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100155 return np.float32(
156 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
157 )
James Ward8b390432022-08-12 20:48:56 +0100158 elif dtype == DType.FP16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100159 return np.float16(
160 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
161 )
James Ward24dbc422022-10-19 12:20:31 +0100162 elif dtype == DType.BF16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100163 rand_f32 = np.float32(
164 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
165 )
James Ward24dbc422022-10-19 12:20:31 +0100166 return vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700167 elif dtype == DType.BOOL:
168 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -0700169 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -0700170 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -0700171 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700172 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100173 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -0700174 elif dtype == DType.INT16:
175 low, high = (-32768, 32768)
176 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800177 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -0700178 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800179 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -0700180 # Special size
181 return np.int64(self.rng.integers(low, high, size=1))[0]
182 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800183 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700184
185 return np.int32(self.rng.integers(low, high, size=1))[0]
186
187 def shapeStr(self, shape):
188
189 sStr = []
190 # Convert to strings
191 for i in shape:
192 sStr.append(str(i))
193
Kevin Cheng550ccc52021-03-03 11:21:43 -0800194 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700195
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100196 def typeStr(self, dtype):
197 if isinstance(dtype, list) or isinstance(dtype, tuple):
198 assert len(dtype) >= 2
199 strs = [self.typeStr(t) for t in dtype]
200 # Limit types to the first 2 as the 3rd is the accumulator
201 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700202 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100203 if dtype in DTYPE_ATTRIBUTES:
204 return DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700205 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100206 raise Exception(
207 "Unknown dtype, cannot convert to string: {}".format(dtype)
208 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700209
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100210 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100211 """Get the datatype width for data types"""
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100212 if dtype in DTYPE_ATTRIBUTES:
213 return DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700214 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100215 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700216
Luke Hutton57287132023-02-06 14:54:18 +0000217 def constrictBatchSize(self, shape):
218 # Limit the batch size unless an explicit target shape set
219 if self.args.max_batch_size and not self.args.target_shapes:
220 shape[0] = min(shape[0], self.args.max_batch_size)
221 return shape
222
James Ward30124a82023-02-02 14:56:33 +0000223 def makeDimension(self):
224 return self.randInt(
225 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
226 )
227
    # Argument generators
    # Return a list of tuples (stringDescriptor, [build_fcn_arg_list]),
    # where the string descriptor is used to generate the test name and
    # the build_fcn_arg_list is expanded and passed to the operator test
    # build function.
233
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100234 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
235 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
236
Matthew Haddon848efb42021-09-09 12:30:53 +0100237 # build_placeholder returns an int, ABS/other ops does not
238 if isinstance(op, int):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000239 self.ser.addOperator(op, a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100240 return result_tens
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000241 elif op["op"] == Op.IDENTITY:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000242 self.ser.addOperator(op["op"], a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100243 return result_tens
244
245 # Ensure new output type has correct qinfo
246 if error_name == ErrorIf.WrongOutputType:
247 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000248 qinfo = [
249 TosaQuantGen.getZeroPoint(self, a.dtype),
250 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
251 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100252
253 # Invalidate Input/Output list for error if checks.
254 input_list = [a.name]
255 output_list = [result_tens.name]
256 pCount, cCount = op["operands"]
257 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000258 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
259 self, error_name, input_list, output_list
260 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100261
Les Bell729b0352021-11-24 10:28:21 +0000262 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100263 self.ser,
264 validator_fcns,
265 error_name,
266 op=op,
267 input_dtype=a.dtype,
268 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000269 qinfo=qinfo,
Luke Hutton261b7b62023-01-10 14:50:31 +0000270 result_tensors=[result_tens],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100271 input_list=input_list,
272 output_list=output_list,
273 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000274 ):
275 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100276
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000277 attr = None
278 if op["op"] == Op.NEGATE:
279 attr = ts.TosaSerializerAttribute()
280 attr.NegateAttribute(qinfo[0], qinfo[1])
281
282 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700283 return result_tens
284
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100285 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000286 result_tens = OutputShaper.binaryBroadcastOp(
287 self.ser, self.rng, a, b, error_name
288 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100289
290 # Invalidate Input/Output list for error if checks.
291 input_list = [a.name, b.name]
292 output_list = [result_tens.name]
293 pCount, cCount = op["operands"]
294 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000295 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
296 self, error_name, input_list, output_list
297 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100298
Les Bell729b0352021-11-24 10:28:21 +0000299 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100300 self.ser,
301 validator_fcns,
302 error_name,
303 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000304 input1=a,
305 input2=b,
306 input_dtype=a.dtype,
307 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000308 result_tensors=[result_tens],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100309 input_list=input_list,
310 output_list=output_list,
311 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000312 ):
313 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100314
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000315 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -0700316 return result_tens
317
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100318 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700319 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000320 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700321 return result_tens
322
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000323 def build_arithmetic_right_shift(
324 self, op, a, b, round, validator_fcns=None, error_name=None
325 ):
326 result_tens = OutputShaper.binaryBroadcastOp(
327 self.ser, self.rng, a, b, error_name
328 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100329
330 # Invalidate Input/Output list for error if checks.
331 input_list = [a.name, b.name]
332 output_list = [result_tens.name]
333 pCount, cCount = op["operands"]
334 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000335 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
336 self, error_name, input_list, output_list
337 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100338
Les Bell729b0352021-11-24 10:28:21 +0000339 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100340 self.ser,
341 validator_fcns,
342 error_name,
343 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000344 input1=a,
345 input2=b,
346 input_dtype=a.dtype,
347 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000348 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100349 input_list=input_list,
350 output_list=output_list,
351 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000352 ):
353 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800354
355 attr = ts.TosaSerializerAttribute()
356 attr.ArithmeticRightShiftAttribute(round)
357
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000358 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800359 return result_tens
360
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100361 def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000362 result_tens = OutputShaper.binaryBroadcastOp(
363 self.ser, self.rng, a, b, error_name
364 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700365
366 # Special for multiply:
367 # Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100368 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Eric Kunzee5e26762020-10-13 16:11:07 -0700369 result_tens.setDtype(DType.INT32)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100370 if error_name == ErrorIf.WrongOutputType:
371 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
372 outputDType = self.rng.choice(all_dtypes)
373 result_tens.setDtype(outputDType)
374
375 # Invalidate Input/Output list for error if checks.
376 input_list = [a.name, b.name]
377 output_list = [result_tens.name]
378 pCount, cCount = op["operands"]
379 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000380 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
381 self, error_name, input_list, output_list
382 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100383
Les Bell729b0352021-11-24 10:28:21 +0000384 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100385 self.ser,
386 validator_fcns,
387 error_name,
388 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000389 input1=a,
390 input2=b,
391 input_dtype=a.dtype,
392 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000393 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100394 input_list=input_list,
395 output_list=output_list,
396 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000397 ):
398 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700399
Kevin Chengaee1fac2020-11-11 13:54:06 -0800400 attr = ts.TosaSerializerAttribute()
401 attr.MulAttribute(shift)
402
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000403 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700404 return result_tens
405
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100406 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
407 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700408
Kevin Chengfe392ce2021-10-18 21:51:55 +0000409 attr = ts.TosaSerializerAttribute()
410 attr.TableAttribute(table)
411
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100412 # Invalidate Input/Output list for error if checks.
413 input_list = [a.name]
414 output_list = [result_tens.name]
415 pCount, cCount = op["operands"]
416 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000417 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
418 self, error_name, input_list, output_list
419 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100420
Les Bell729b0352021-11-24 10:28:21 +0000421 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100422 self.ser,
423 validator_fcns,
424 error_name,
425 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000426 input_shape=a.shape,
427 input_dtype=a.dtype,
428 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000429 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100430 input_list=input_list,
431 output_list=output_list,
432 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000433 ):
434 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100435
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000436 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700437
438 return result_tens
439
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100440 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
441 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
442
443 # Invalidate Input/Output list for error if checks.
444 input_list = [cond.name, a.name, b.name]
445 output_list = [result_tens.name]
446 pCount, cCount = op["operands"]
447 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000448 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
449 self, error_name, input_list, output_list
450 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100451
Les Bell729b0352021-11-24 10:28:21 +0000452 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100453 self.ser,
454 validator_fcns,
455 error_name,
456 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000457 input1=cond,
458 input2=a,
459 input3=b,
460 input_shape=a.shape,
461 input_dtype=a.dtype,
462 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000463 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100464 input_list=input_list,
465 output_list=output_list,
466 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000467 ):
468 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100469
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000470 self.ser.addOperator(
471 op["op"],
472 input_list,
473 output_list,
474 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700475 return result_tens
476
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100477 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000478 result_tens = OutputShaper.binaryComparisonOp(
479 self.ser, self.rng, a, b, error_name
480 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100481
482 # Invalidate Input/Output list for error if checks.
483 input_list = [a.name, b.name]
484 output_list = [result_tens.name]
485 pCount, cCount = op["operands"]
486 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000487 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
488 self, error_name, input_list, output_list
489 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100490
Les Bell729b0352021-11-24 10:28:21 +0000491 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100492 self.ser,
493 validator_fcns,
494 error_name,
495 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000496 input1=a,
497 input2=b,
498 input_shape=a.shape,
499 input_dtype=a.dtype,
500 output_shape=result_tens.shape,
501 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000502 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100503 input_list=input_list,
504 output_list=output_list,
505 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000506 ):
507 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100508
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000509 self.ser.addOperator(
510 op["op"],
511 input_list,
512 output_list,
513 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700514 return result_tens
515
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100516 def build_argmax(self, op, a, axis, validator_fcns, error_name):
517 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
518
519 # Invalidate Input/Output list for error if checks.
520 input_list = [a.name]
521 output_list = [result_tens.name]
522 pCount, cCount = op["operands"]
523 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000524 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
525 self, error_name, input_list, output_list
526 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100527
Les Bell729b0352021-11-24 10:28:21 +0000528 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100529 self.ser,
530 validator_fcns,
531 error_name,
532 op=op,
533 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000534 input_shape=a.shape,
535 input_dtype=a.dtype,
536 output_shape=result_tens.shape,
537 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000538 result_tensors=[result_tens],
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100539 input_list=input_list,
540 output_list=output_list,
541 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000542 ):
543 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700544
545 attr = ts.TosaSerializerAttribute()
546 attr.AxisAttribute(axis)
547
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000548 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700549 return result_tens
550
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000551 def build_pool2d(
552 self,
553 op,
554 input,
James Ward8b390432022-08-12 20:48:56 +0100555 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000556 stride,
557 pad,
558 kernel,
559 validator_fcns=None,
560 error_name=None,
561 qinfo=None,
562 ):
563 result_tens = OutputShaper.pool2dOp(
564 self.ser, self.rng, input, kernel, stride, pad, error_name
565 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100566
567 # Ensure new output type has correct qinfo
568 if error_name == ErrorIf.WrongInputType:
569 if input.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000570 qinfo = [
571 TosaQuantGen.getZeroPoint(self, input.dtype),
572 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
573 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100574
575 # Invalidate Input/Output list for error if checks.
576 input_list = [input.name]
577 output_list = [result_tens.name]
578 pCount, cCount = op["operands"]
579 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000580 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
581 self, error_name, input_list, output_list
582 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100583
Les Bell729b0352021-11-24 10:28:21 +0000584 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100585 self.ser,
586 validator_fcns,
587 error_name,
588 op=op,
589 input_shape=input.shape,
590 input_dtype=input.dtype,
591 output_shape=result_tens.shape,
592 output_dtype=result_tens.dtype,
593 kernel=kernel,
594 stride=stride,
595 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000596 qinfo=qinfo,
Luke Hutton261b7b62023-01-10 14:50:31 +0000597 result_tensors=[result_tens],
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100598 input_list=input_list,
599 output_list=output_list,
600 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000601 ):
602 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700603
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000604 if qinfo is None:
605 qinfo = [0, 0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700606
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000607 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100608 attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000609
610 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700611 return result_tens
612
James Ward8b390432022-08-12 20:48:56 +0100613 def build_maxpool2d(
614 self,
615 op,
616 input,
617 stride,
618 pad,
619 kernel,
620 validator_fcns=None,
621 error_name=None,
622 qinfo=None,
623 ):
624 # Same as build_pool2d but manually sets accum_dtype value
625 # (maxpool has no accum_dtype)
626 return self.build_pool2d(
627 op,
628 input,
629 DType.UNKNOWN,
630 stride,
631 pad,
632 kernel,
633 validator_fcns,
634 error_name,
635 qinfo,
636 )
637
def build_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator and add it to the serializer.

    The result tensor comes from OutputShaper.conv2dOp. For an
    ErrorIf.WrongInputType test whose input dtype is neither INT8 nor UINT8,
    qinfo is regenerated before validation. If
    TosaErrorValidator.evValidateErrorIfs returns a falsy value, no operator
    is added and None is returned.

    Args:
        padding: exactly four padding values (asserted).
        qinfo: two-element sequence; qinfo[0] and qinfo[1] are forwarded as
            the last two ConvAttribute arguments (presumably input and weight
            zero points - confirm against TosaQuantGen.qgConv).

    Returns:
        The result tensor, or None if validation rejected the test.
    """
    assert len(padding) == 4
    result_tens = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    wrong_input = error_name == ErrorIf.WrongInputType
    if wrong_input and ifm.dtype not in (DType.INT8, DType.UINT8):
        # Keep qinfo consistent with the substituted input type.
        # NOTE(review): the second zero point is derived from
        # result_tens.dtype rather than filter.dtype - confirm intended.
        zp_input = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        zp_second = TosaQuantGen.getZeroPoint(self, result_tens.dtype)
        qinfo = [zp_input, zp_second]

    # Invalidate Input/Output list for error_if checks.
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [result_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
709
def build_conv3d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator and add it to the serializer.

    The result tensor comes from OutputShaper.conv3dOp. For an
    ErrorIf.WrongInputType test with a non-INT8/UINT8 input, qinfo is
    regenerated. The operator is added with a ConvAttribute only when
    TosaErrorValidator.evValidateErrorIfs returns a truthy value.

    Args:
        padding: exactly six padding values (asserted).
        qinfo: two-element sequence forwarded as the last two ConvAttribute
            arguments.

    Returns:
        The result tensor, or None if validation rejected the test.
    """
    assert len(padding) == 6
    out = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    if error_name == ErrorIf.WrongInputType:
        if ifm.dtype not in (DType.INT8, DType.UINT8):
            # Ensure new output type has correct qinfo.
            # NOTE(review): second zero point uses out.dtype, not
            # filter.dtype - confirm intended.
            qinfo = [
                TosaQuantGen.getZeroPoint(self, ifm.dtype),
                TosaQuantGen.getZeroPoint(self, out.dtype),
            ]

    # Invalidate Input/Output list for error_if checks.
    operand_count = sum(op["operands"])
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out.name]
    )

    checks = dict(
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_count,
        output_list=outputs,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out.shape,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out
781
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator and add it to the serializer.

    The result tensor is created by OutputShaper.transposeConv2DOp from the
    requested output_shape. For an ErrorIf.WrongInputType test with a
    non-INT8/UINT8 input, qinfo is regenerated. ERROR_IF validation receives
    out_pad as `pad`; no dilation is passed. The operator is added with a
    TransposeConvAttribute only if TosaErrorValidator.evValidateErrorIfs
    returns a truthy value.

    Args:
        out_pad: exactly four padding values (asserted).
        qinfo: two-element sequence forwarded as the last two
            TransposeConvAttribute arguments.

    Returns:
        The result tensor, or None if validation rejected the test.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    quantized_input = ifm.dtype in (DType.INT8, DType.UINT8)
    if error_name == ErrorIf.WrongInputType and not quantized_input:
        # Ensure new output type has correct qinfo.
        in_zp = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        out_zp = TosaQuantGen.getZeroPoint(self, result_tens.dtype)
        qinfo = [in_zp, out_zp]

    # Invalidate Input/Output list for error_if checks.
    operand_total = sum(op["operands"])
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [result_tens.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=names_in,
        num_operands=operand_total,
        output_list=names_out,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    )
    if not passed:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], names_in, names_out, attr)
    return result_tens
844
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator and add it to the serializer.

    The result tensor comes from OutputShaper.depthwiseConv2dOp. For an
    ErrorIf.WrongInputType test with a non-INT8/UINT8 input, qinfo is
    regenerated. If TosaErrorValidator.evValidateErrorIfs returns a falsy
    value, no operator is added and None is returned.

    Args:
        padding: exactly four padding values (asserted, matching the
            check in the CONV2D builder).
        qinfo: two-element sequence forwarded as the last two ConvAttribute
            arguments.

    Returns:
        The result tensor, or None if validation rejected the test.
    """
    # Same 2-D padding contract as CONV2D; fail fast on a malformed argument
    # instead of producing a bad test later.
    assert len(padding) == 4
    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        # NOTE(review): the second zero point is derived from
        # result_tens.dtype rather than filter.dtype - confirm intended.
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
915
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator and add it to the serializer.

    The result tensor comes from OutputShaper.fullyConnectedOp. After the
    ERROR_IF operand-list invalidation step, the validators are run. If
    TosaErrorValidator.evValidateErrorIfs returns a falsy value, None is
    returned. Otherwise the operator is added with a FullyConnectedAttribute
    built from qinfo[0] and qinfo[1].

    Returns:
        The result tensor, or None if validation rejected the test.
    """
    result_tens = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [result_tens.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        accum_dtype=accum_dtype,
    )
    if not ok:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return result_tens
964
def build_matmul(
    self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator for inputs `a` and `b` and serialize it.

    The result tensor comes from OutputShaper.matmulOp. The ERROR_IF
    validators see both inputs' shapes and dtypes plus accum_dtype. If
    TosaErrorValidator.evValidateErrorIfs returns a falsy value, None is
    returned. Otherwise the operator is added with a MatMulAttribute built
    from qinfo[0] and qinfo[1].

    Returns:
        The result tensor, or None if validation rejected the test.
    """
    result_tens = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    operand_count = n_placeholders + n_consts
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    validation_args = dict(
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_count,
        accum_dtype=accum_dtype,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return result_tens
1006
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Build a REDUCE_* operator over `axis` of tensor `a` and serialize it.

    The result tensor comes from OutputShaper.reduceOp. If
    TosaErrorValidator.evValidateErrorIfs returns a falsy value, no operator
    is added and None is returned. Otherwise the operator is added with an
    AxisAttribute.

    `validator_fcns` now defaults to None, like every other builder, so the
    method can be called without it. Existing callers are unaffected.

    Returns:
        The result tensor, or None if validation rejected the test.
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1041
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Build a CLAMP operator on tensor `a` with random bounds and serialize it.

    Two values are drawn with getRandNumberDType(a.dtype). Normally the
    smaller becomes min_val and the larger max_val. For ErrorIf.MaxSmallerMin
    the pair is redrawn until the values differ, and the roles are swapped so
    that max_val < min_val.

    For BF16/FP16/FP32 inputs the bounds are passed as the third and fourth
    ClampAttribute arguments, with zeros in the first two. FP16 bounds are
    cast to np.float32 first. Other dtypes pass the bounds as the first two
    arguments and zeros as the last two.

    Returns:
        The result tensor, or None if validation rejected the test.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def draw_pair():
        # Two random values of the input dtype, drawn left to right.
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    pair = draw_pair()
    if error_name != ErrorIf.MaxSmallerMin:
        min_val, max_val = min(pair), max(pair)
    else:
        # The values must differ to provoke the error.
        while pair[0] == pair[1]:
            pair = draw_pair()
        min_val, max_val = max(pair), min(pair)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not valid:
        return None

    attr = ts.TosaSerializerAttribute()
    if a.dtype not in (DType.BF16, DType.FP16, DType.FP32):
        attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1097
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a LEAKY_RELU operator on tensor `a` and serialize it.

    The attribute value is a random FP32 number from getRandNumberDType.
    No ERROR_IF validation is done here: validator_fcns is accepted but
    unused, and error_name is only forwarded to OutputShaper.unaryOp.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    alpha = self.getRandNumberDType(DType.FP32)
    attr.LeakyReluAttribute(alpha)

    inputs, outputs = [a.name], [result_tens.name]
    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return result_tens
1106
# NOTE: needs an additional type/input; only the single input `a` is wired up.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PRELU operator on tensor `a` and serialize it, with no attribute.

    No ERROR_IF validation is done: validator_fcns is unused, and error_name
    is only forwarded to OutputShaper.unaryOp.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    inputs = [a.name]
    outputs = [result_tens.name]
    self.ser.addOperator(op["op"], inputs, outputs)
    return result_tens
1113
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Build a SIGMOID operator on tensor `a` and serialize it.

    The operator is added with no attribute, and only if
    TosaErrorValidator.evValidateErrorIfs returns a truthy value.

    Returns:
        The result tensor, or None if validation rejected the test.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate Input/Output list for error if checks.
    operand_total = sum(op["operands"][:2])
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=names_in,
        output_list=names_out,
        num_operands=operand_total,
    )
    if TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        self.ser.addOperator(op["op"], names_in, names_out)
        return result_tens
    return None
1144
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Build a TANH operator on tensor `a` and serialize it.

    The operator is added with no attribute once
    TosaErrorValidator.evValidateErrorIfs returns a truthy value. Otherwise
    None is returned.

    Returns:
        The result tensor, or None if validation rejected the test.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholder_count + const_count,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1175
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Build a CONCAT operator and serialize it.

    `a` holds the variable-length list of input tensors followed by the
    concatenation axis as its final element. The axis travels this way
    because the number of inputs varies. Except for ErrorIf.WrongInputType
    tests, the axis is asserted to be exactly an `int`.

    Returns:
        The result tensor, or None if validation rejected the test.
    """
    if error_name != ErrorIf.WrongInputType:
        assert type(a[-1]) is int

    axis = a[-1]
    tensors = a[:-1]

    result_tens = OutputShaper.concatOp(
        self.ser, self.rng, axis, *tensors, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [t.name for t in tensors], [result_tens.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=tensors[0].shape,
        output_shape=result_tens.shape,
        input_dtype=tensors[0].dtype,
        output_dtype=result_tens.dtype,
        inputs=tensors,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001224
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator on tensor `a` and serialize it.

    The PadAttribute is built from the flattened `padding` and both pad
    constants before validation, because it takes self.ser.builder. The
    operator is added only if TosaErrorValidator.evValidateErrorIfs returns
    a truthy value.

    Returns:
        The result tensor, or None if validation rejected the test.
    """
    result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    attr = ts.TosaSerializerAttribute()
    flat_padding = padding.flatten()
    attr.PadAttribute(self.ser.builder, flat_padding, pad_const_int, pad_const_float)

    # Invalidate Input/Output list for error if checks.
    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001273
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Build a RESHAPE of tensor `a` to `newShape` and serialize it.

    The result tensor comes from OutputShaper.reshapeOp. The operator is
    added with a ReshapeAttribute(newShape) only if
    TosaErrorValidator.evValidateErrorIfs returns a truthy value.

    Returns:
        The result tensor, or None if validation rejected the test.
    """
    reshaped = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    operand_count = n_placeholders + n_consts
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [reshaped.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=reshaped.shape,
        input_dtype=a.dtype,
        output_dtype=reshaped.dtype,
        result_tensors=[reshaped],
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_count,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ReshapeAttribute(newShape)
    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return reshaped
1309
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Build a REVERSE of tensor `a` along `axis` and serialize it.

    The output shape comes from OutputShaper.unaryOp. The operator is added
    with an AxisAttribute(axis) only if TosaErrorValidator.evValidateErrorIfs
    returns a truthy value.

    Returns:
        The result tensor, or None if validation rejected the test.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    validation_kwargs = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
        self.ser.addOperator(op["op"], input_list, output_list, attr)
        return result_tens
    return None
1344
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Build a TRANSPOSE of tensor `a` using permutation `perms` and serialize it.

    The TransposeAttribute is created before validation, matching the
    original ordering. The operator is added only if
    TosaErrorValidator.evValidateErrorIfs returns a truthy value.

    Returns:
        The result tensor, or None if validation rejected the test.
    """
    out = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(perms)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
        input1=a,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out
1380
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Build a SLICE of tensor `a` given `start` and `size`, and serialize it.

    The result tensor comes from OutputShaper.sliceOp. The operator is added
    with a SliceAttribute(start, size) only if
    TosaErrorValidator.evValidateErrorIfs returns a truthy value.

    Returns:
        The result tensor, or None if validation rejected the test.
    """
    sliced = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    # Invalidate Input/Output list for error if checks.
    placeholders, consts = op["operands"]
    operand_count = placeholders + consts
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [sliced.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=sliced.shape,
        input_dtype=a.dtype,
        output_dtype=sliced.dtype,
        start=start,
        size=size,
        result_tensors=[sliced],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_count,
        input1=a,
    )
    if not ok:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return sliced
1419
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Serialize a TILE of ``a`` repeated according to ``multiples``.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    out_tensor = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    # Operand count = placeholders + constants from the op definition
    placeholders, consts = op["operands"]

    # ERROR_IF tests may deliberately invalidate the input/output name lists
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    error_if_args = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        input1=a,
    )
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    )
    if not checks_ok:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)
    self.ser.addOperator(op["op"], in_names, out_names, tile_attr)
    return out_tensor
1454
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Serialize a GATHER from ``values`` using freshly generated indices.

    An INT32 index constant of shape [values.shape[0], W] is created here.
    W is drawn from ``self.args.tensor_shape_range``. Entries lie in
    [0, values.shape[1]), so they never exceed the values tensor.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    num_values = values.shape[1]  # K
    width = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )  # W
    index_data = np.int32(
        self.rng.integers(low=0, high=num_values, size=[values.shape[0], width])
    )  # (N, W)
    indicies = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.gatherOp(
        self.ser, self.rng, values, indicies, error_name
    )

    # Operand count = placeholders + constants from the op definition
    placeholders, consts = op["operands"]

    # ERROR_IF tests may deliberately invalidate the input/output name lists
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indicies.name], [out_tensor.name]
    )

    error_if_args = dict(
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1501
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Serialize a SCATTER of ``input`` into ``values_in`` at generated indices.

    An INT32 index constant of shape [values_in.shape[0], input.shape[1]] is
    created here. Its entries lie in [0, values_in.shape[1]).

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    num_values = values_in.shape[1]  # K
    width = input.shape[1]  # W
    index_data = np.int32(
        self.rng.integers(low=0, high=num_values, size=[values_in.shape[0], width])
    )  # (N, W)
    indicies = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indicies, input, error_name
    )

    # Operand count = placeholders + constants from the op definition
    placeholders, consts = op["operands"]

    # ERROR_IF tests may deliberately invalidate the input/output name lists
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indicies.name, input.name],
        [out_tensor.name],
    )

    error_if_args = dict(
        op=op,
        input_shape=values_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1546
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Serialize a RESIZE of ``input`` using ``mode``/``scale``/``offset``/``border``.

    The output tensor is shaped by OutputShaper.resizeOp. The same
    parameters feed the ERROR_IF validators and the RESIZE attribute.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    out_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    # Operand count = placeholders + constants from the op definition
    placeholders, consts = op["operands"]

    # ERROR_IF tests may deliberately invalidate the input/output name lists
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out_tensor.name]
    )

    error_if_args = dict(
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=out_tensor.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[out_tensor],
        num_operands=placeholders + consts,
    )
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    )
    if not checks_ok:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return out_tensor
1608
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Serialize an IDENTITYN operator with the two inputs ``val`` and ``val2``.

    Each output tensor is shaped from its input by OutputShaper.unaryOp.
    No ERROR_IF validation is performed here. Only the first result tensor
    is returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # ``op`` is the operator definition dict; the serializer needs its
    # ``op["op"]`` entry, exactly as the other build_* methods pass it.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1616
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Mark ``val`` as a graph output tensor and return it unchanged.

    No operator is added. ``op``, ``validator_fcns`` and ``error_name`` are
    unused here (presumably kept so every build_* method shares one
    signature).
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001620
1621 # Type Conversion
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001622 def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001623 result_tens = OutputShaper.typeConversionOp(
1624 self.ser, self.rng, val, out_dtype, error_name
1625 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001626
1627 # Invalidate Input/Output list for error if checks.
1628 input_list = [val.name]
1629 output_list = [result_tens.name]
1630 pCount, cCount = op["operands"]
1631 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001632 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1633 self, error_name, input_list, output_list
1634 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001635
Les Bell729b0352021-11-24 10:28:21 +00001636 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001637 self.ser,
1638 validator_fcns,
1639 error_name,
1640 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001641 input_shape=val.shape,
1642 output_shape=result_tens.shape,
1643 input_dtype=val.dtype,
1644 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +00001645 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001646 input_list=input_list,
1647 output_list=output_list,
1648 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001649 ):
1650 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001651
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001652 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07001653 return result_tens
1654
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Serialize a RESCALE of ``val`` to ``out_dtype`` with random parameters.

    The builder does the following:

    * Picks input/output zero points. They are legal for the types, or
      deliberately non-zero for the zero-point ERROR_IF tests.
    * Derives a random scale, per channel (last dimension of ``val``) when
      ``per_channel`` is set, otherwise a single one.
    * Converts each scale to a multiplier/shift pair.
    * For valid scale32 tests only, clamps the saved placeholder data so it
      meets the apply_scale_32 input range.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    if per_channel:
        nc = val.shape[-1]
    else:
        nc = 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Input zero point. Every branch that may choose a non-zero zero point
    # also widens the input type by one bit before the scale is derived.
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        # ERROR_IF test: force a non-zero input zero point
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    # Output zero point, chosen with the same rules as the input
    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        # ERROR_IF test: force a non-zero output zero point
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Compute the bounds with a Python int. Shifting an np.int32 stays
        # in int32 under NumPy's NEP 50 promotion and overflows for shifts
        # above 32, which scale32 can produce.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 1)
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1809
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor for a COND_IF holding ``cond``.

    Normally the tensor is a rank-0 BOOL. The CondIfCondNotMatchingBool
    ERROR_IF test gets a wrong dtype. The CondIfCondShapeNotSizeOne test
    gets a shape of [2] or [1, 2].
    """
    if error_name == ErrorIf.CondIfCondNotMatchingBool:
        cond_type = get_wrong_output_type(op, self.rng, DType.BOOL)
    else:
        cond_type = DType.BOOL

    # Must be of size 1 (rank 0) unless the shape ERROR_IF is requested
    cond_shape = []
    if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
        cond_shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
1826
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN and ELSE blocks each output an INT32 constant.

    From the supplied tensors only ``then_tens.shape`` is used;
    ``else_tens`` is ignored. ``cond`` becomes the value of the condition
    constant. For the output-list mismatch ERROR_IF tests, the block under
    test outputs a constant of a perturbed shape instead.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    if error_name in (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ):
        # Perturb every dimension; small dimensions only grow so they stay
        # positive.
        incorrect_shape = [
            dim
            + (
                self.rng.choice([-3, -2, 2, 3])
                if dim > 3
                else self.rng.choice([1, 2, 4])
            )
            for dim in then_tens.shape
        ]
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The result tensor matches the shape of the block outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block holds one constant output; the matching mismatch error
    # swaps in the wrongly shaped constant for that block.
    for block_name, block_arr, mismatch_error in (
        (then_block, then_arr, ErrorIf.CondIfOutputListThenGraphMismatch),
        (else_block, else_arr, ErrorIf.CondIfOutputListElseGraphMismatch),
    ):
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            block_tens = self.ser.addConst(
                incorrect_shape, DType.INT32, incorrect_arr
            )
        else:
            block_tens = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_tens)

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not checks_ok:
        return None

    return result_tens
1895
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks apply a binary op to ``a`` and ``b``.

    The block ops depend on the dtype of ``a``:

    * FP32/BF16/FP16/INT32: THEN adds, ELSE subtracts.
    * INT8/INT16: THEN is LOGICAL_RIGHT_SHIFT, ELSE is LOGICAL_LEFT_SHIFT.

    For the input/output list mismatch ERROR_IF tests, the block under test
    declares a tensor with a perturbed shape.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            # Same perturbation as build_cond_if_const: dimensions <= 3 only
            # grow, so the mismatched shape never has zero/negative dims.
            incorrect_shape[i] += (
                self.rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else self.rng.choice([1, 2, 4])
            )
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # The loop variable must not be named ``op``: that would shadow the
    # operator dict passed to the ERROR_IF validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
1978
Matthew Haddon630c17c2021-10-14 15:05:41 +01001979 def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001980 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07001981
Kevin Cheng550ccc52021-03-03 11:21:43 -08001982 cond_block = "COND_BLOCK"
1983 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001984
1985 attr = ts.TosaSerializerAttribute()
1986 attr.WhileLoopAttribute(cond_block, body_block)
1987
1988 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001989 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001990 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001991 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07001992
1993 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08001994 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
1995 a_out = self.ser.addIntermediate(a.shape, a.dtype)
Matthew Haddon630c17c2021-10-14 15:05:41 +01001996 if error_name == ErrorIf.InputListOutputListMismatch:
1997 incorrect_acc = deepcopy(acc)
1998 for i in range(len(incorrect_acc.shape)):
1999 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2000 acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
2001 else:
2002 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002003
2004 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08002005 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002006 op["op"],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002007 [iter.name, a.name, acc.name],
2008 [iter_out.name, a_out.name, acc_out.name],
2009 attr,
2010 )
Kevin Chengb227ae52021-09-02 13:43:17 -07002011 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07002012
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002013 if error_name in [
2014 ErrorIf.InputListCondGraphMismatch,
2015 ErrorIf.InputListBodyGraphInputMismatch,
2016 ErrorIf.InputListBodyGraphOutputMismatch,
2017 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01002018 incorrect_iter = deepcopy(iter)
2019 for i in range(len(incorrect_iter.shape)):
2020 incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
2021 if len(incorrect_iter.shape) == 0:
2022 incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
2023
2024 incorrect_acc = deepcopy(acc)
2025 for i in range(len(incorrect_acc.shape)):
2026 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2027
Eric Kunzee5e26762020-10-13 16:11:07 -07002028 # COND block (input: iter, output: cond_tens )
Jerry Ge9e94af82022-10-27 09:57:00 -07002029 self.ser.addBasicBlock(cond_block)
2030
Matthew Haddon630c17c2021-10-14 15:05:41 +01002031 if error_name == ErrorIf.InputListCondGraphMismatch:
2032 self.ser.addInputTensor(incorrect_iter)
2033 self.ser.addInputTensor(a)
2034 self.ser.addInputTensor(incorrect_acc)
2035 else:
2036 self.ser.addInputTensor(iter)
2037 self.ser.addInputTensor(a)
2038 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002039 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002040
2041 if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002042 cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002043 else:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002044 cond_type = DType.BOOL
2045 if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
2046 choice = self.rng.choice([1, 2])
2047 if choice == 1:
2048 cond_shape = [3]
2049 else:
2050 cond_shape = [1, 2]
2051 else:
2052 cond_shape = []
2053 cond_tens = self.ser.addOutput(cond_shape, cond_type)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002054
Kevin Cheng550ccc52021-03-03 11:21:43 -08002055 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002056
2057 # BODY block (input: a, acc, iter, output: a, acc, iter)
2058 # Note that local intermediate tensors need to be declared here for the outputs
Jerry Ge9e94af82022-10-27 09:57:00 -07002059 self.ser.addBasicBlock(body_block)
2060
Matthew Haddon630c17c2021-10-14 15:05:41 +01002061 if error_name == ErrorIf.InputListBodyGraphInputMismatch:
2062 self.ser.addInputTensor(incorrect_iter)
2063 self.ser.addInputTensor(a)
2064 self.ser.addInputTensor(incorrect_acc)
2065 else:
2066 self.ser.addInputTensor(iter)
2067 self.ser.addInputTensor(a)
2068 self.ser.addInputTensor(acc)
2069
Kevin Cheng550ccc52021-03-03 11:21:43 -08002070 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002071
2072 if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002073 iter_body_out = self.ser.addIntermediate(
2074 incorrect_iter.shape, incorrect_iter.dtype
2075 )
2076 acc_body_out = self.ser.addIntermediate(
2077 incorrect_acc.shape, incorrect_acc.dtype
2078 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01002079 else:
2080 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2081 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
2082
Eric Kunzee5e26762020-10-13 16:11:07 -07002083 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2084 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2085 self.ser.addOutputTensor(iter_body_out)
2086 self.ser.addOutputTensor(a)
2087 self.ser.addOutputTensor(acc_body_out)
2088
Les Bell729b0352021-11-24 10:28:21 +00002089 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01002090 self.ser,
2091 validator_fcns,
2092 error_name,
2093 op=op,
Jerry Ge9e94af82022-10-27 09:57:00 -07002094 basicBlocks=self.ser.currRegion.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00002095 ):
2096 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01002097
Eric Kunzee5e26762020-10-13 16:11:07 -07002098 return acc_out
2099
def build_fft2d(
    self, op, val1, val2, inverse, validator_fcns=None, error_name=None
):
    """Build an FFT2D operator from two input tensors.

    The result tensors come from ``OutputShaper.fft2dOp``. The input and
    output name lists are passed through the ERROR_IF list helper for
    ``error_name``, then the ERROR_IF validators are run. If validation fails
    the operator is not added and None is returned. Otherwise the operator
    is added with an FFT attribute built from ``inverse``, and the list of
    result tensors is returned.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholders, consts = op["operands"]

    # Snapshot shapes/dtypes from the result tensors before the name lists
    # are (possibly) altered for an ERROR_IF test.
    out_names = [tensor.name for tensor in results]
    out_shapes = [tensor.shape for tensor in results]
    out_dtypes = [tensor.dtype for tensor in results]

    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val1.name, val2.name], out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse)
    self.ser.addOperator(op["op"], in_names, out_names, fft_attr)
    return results
2141
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
    """Build an RFFT2D operator from a single input tensor.

    The result tensors come from ``OutputShaper.rfft2dOp``. The input and
    output name lists are passed through the ERROR_IF list helper for
    ``error_name``, then the ERROR_IF validators are run. If validation fails
    the operator is not added and None is returned. Otherwise the operator
    is added (with no attribute) and the list of result tensors is returned.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholders, consts = op["operands"]

    # Snapshot shapes/dtypes from the result tensors before the name lists
    # are (possibly) altered for an ERROR_IF test.
    out_names = [tensor.name for tensor in results]
    out_shapes = [tensor.shape for tensor in results]
    out_dtypes = [tensor.dtype for tensor in results]

    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return results
2175
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape, rank and dtype filters to generate tests with.

    Args:
        op: TOSA_OP_LIST entry for the operator; its "rank" (min, max) and
            "types" entries are read.
        shapeFilter: requested shapes; a falsy value means "no shape filter"
            and is normalized to [None].
        rankFilter: requested ranks, or None to use a default range. Ranks
            the operator does not allow are dropped.
        dtypeFilter: requested dtypes, or None for all of the op's types. An
            op type that is a list matches when its first element is
            requested.
        testType: "positive" or "negative".
        validator: ERROR_IF validator, only used for negative tests. It is
            called as ``validator(check=False, op=op)`` and any non-None
            "rank"/"dtype"/"shape" entry in its "param_reqs" replaces the
            corresponding filter.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or
        None for a negative test without a validator or an unknown testType.
    """
    # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    if rankFilter is not None:
        # Ensure rankFilter values are allowed by operator
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Ensure default behaviour is bounded by default range or by operator,
        # whichever is the smaller range of ranks.
        opRankRange = range(rmin, rmax + 1)
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            # Clip the default range to the ranks the operator allows so we
            # never ask for a rank the operator does not support (matches
            # the explicit rankFilter handling above).
            cleanRankFilter = range(
                max(rmin, default_test_rank_range.start),
                min(rmax + 1, default_test_rank_range.stop),
            )
            if not cleanRankFilter:
                # No overlap with the default range: keep the same number of
                # ranks, starting from the operator's smallest allowed rank.
                cleanRankFilter = opRankRange[: len(default_test_rank_range)]
    else:
        # Shapes were requested: allow every operator rank, the shapes
        # themselves restrict which ranks are generated.
        cleanRankFilter = range(rmin, rmax + 1)

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Create list of operator dtypes filtered by requested dtypes
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType != "negative" or validator is None:
        return None

    error_arguments = validator(check=False, op=op)["param_reqs"]

    # Set parameters as required
    if error_arguments["rank"] is not None:
        rankFilter = error_arguments["rank"]
    else:
        rankFilter = cleanRankFilter

    if error_arguments["dtype"] is not None:
        dtypeFilter = error_arguments["dtype"]
    else:
        dtypeFilter = cleanDtypeFilter

    if error_arguments["shape"] is not None:
        shapeFilter = error_arguments["shape"]
    else:
        # Reduce number of shapes to keep test numbers small
        shapeFilter = shapeFilter[:2]

    return {
        "shapeFilter": shapeFilter,
        "rankFilter": rankFilter,
        "dtypeFilter": dtypeFilter,
    }
2256
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests to create for one operator.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter: shapes to restrict the tests to. None (the default) or
            an empty list means no shape filter. Shapes whose length differs
            from the rank being generated are skipped.
        rankFilter: ranks to restrict the tests to, or None.
        dtypeFilter: dtypes to restrict the tests to, or None.
        testType: "positive" for normal tests, "negative" for ERROR_IF tests.

    Returns:
        A list of tuples ``(opName, testStr, dtype, error_name, shapeList,
        args)``. It is empty if ``create_filter_lists`` returns None.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    # Only the tensor-shape and argument generators are needed here; the
    # 4-tuple unpack still checks the build_fcn entry is well formed.
    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of a tuple of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of tuples of the (str, []) string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a
        # ERROR_IF statement. A test is dropped if ANY validator flags it
        # (previously only the last validator's verdict was honoured).
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2363
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build one test case for ``opName`` and serialize it if it is valid.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        testStr: test name, passed to ``self.createSerializer``.
        dtype_or_dtypeList: a list with one dtype per operand, or a single
            dtype repeated for every operand (for CONCAT, once per entry in
            ``shapeList``).
        error_name: ERROR_IF error name, or None for a positive test.
        shapeList: operand shapes.
        testArgs: extra positional arguments for the op's build function.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
        TypeError: re-raised from the build function after printing the
            function, tensors and arguments to help debugging.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _tgen_fcn, tvgen_fcn, _agen_fcn = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholders, consts = op["operands"]
    num_operands = placeholders + consts
    is_concat = op["op"] == Op.CONCAT

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        # CONCAT gets one dtype per shape and skips the operand-count checks
        copies = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * copies

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Build the random tensor operands and the test
    tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

    if error_if_validators is None:
        # Without validators, qinfo (if any) is passed positionally
        trailing_args = () if qinfo is None else (qinfo,)
        build_kwargs = {}
    else:
        trailing_args = ()
        build_kwargs = {
            "validator_fcns": error_if_validators,
            "error_name": error_name,
        }
        if qinfo is not None:
            build_kwargs["qinfo"] = qinfo

    try:
        resultName = build_fcn(
            self, op, *tens, *testArgs, *trailing_args, **build_kwargs
        )
    except TypeError as e:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise e

    if not resultName:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it
    self.serialize("test")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002454
def createDynamicOpLists(self):
    """Expand the templated convolution entries of TOSA_OP_LIST.

    For each kernel size, a shallow copy of the matching "*_TEMPLATE" entry
    is stored under a name such as "conv2d_3x3". The copy has "filter" set
    to the kernel and "template" set to False. Afterwards every entry whose
    "template" flag is truthy is removed. The function returns immediately
    if "conv2d_TEMPLATE" is absent, i.e. the lists were already expanded.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
    kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    # Every 2D kernel yields conv2d, depthwise_conv2d and transpose_conv2d
    # variants (inserted in that order), followed by the conv3d variants.
    expansions = [
        (prefix, kernel)
        for kernel in kernels_2d
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d")
    ] + [("conv3d", kernel) for kernel in kernels_3d]

    for prefix, kernel in expansions:
        name = "{}_{}".format(prefix, "x".join(str(dim) for dim in kernel))
        entry = op_list[prefix + "_TEMPLATE"].copy()
        entry["filter"] = kernel
        entry["template"] = False
        op_list[name] = entry

    # Collect first, then delete: a dict must not change size while it is
    # being iterated.
    templates = [name for name, entry in op_list.items() if entry.get("template")]
    for name in templates:
        del op_list[name]
2505
def initOpListDefaults(self):
    """Lint every TOSA_OP_LIST entry and fill in optional defaults.

    Each entry must provide a two-element "operands" tuple, a four-element
    "build_fcn" tuple, a "types" entry and an "op" field. Otherwise an
    Exception naming the offending op is raised. Entries without a "rank"
    get ``self.DEFAULT_RANK_RANGE``.
    """
    for name, entry in self.TOSA_OP_LIST.items():
        # Required fields: unpacking checks both presence and arity
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(name)
            )

        try:
            _build, _tgen, _tvgen, _agen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    name
                )
            )

        try:
            entry["types"]
        except KeyError:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(name)
            )

        try:
            entry["op"]
        except KeyError:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(name)
            )

        # Optional field: put in the default rank range if missing
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07002547
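# Tensor operator list
# 'op': TOSA Op enum value of the operator
# 'operands': tuple of (placeholder, const) operands
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#     if not specified, defaults to DEFAULT_RANK_RANGE (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': tuple of (build_operator(), TensorGen function,
#     TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested
# 'qgen': optional, quantization info generator function
# 'error_if_validators': optional, ERROR_IF validators used for negative tests
# 'invalid_test_validators': optional, validators that remove positive tests
#     which are expected to fail but don't correlate to an ERROR_IF statement
# 'template'/'filter': templated entries (template=True) are expanded by
#     createDynamicOpLists into one entry per kernel 'filter' size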
# Floating-point dtypes
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

# Integer dtypes (Excludes INT4), and integer followed by floating-point dtypes
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]
TYPE_INT_FP = TYPE_INT + [DType.FP16, DType.BF16, DType.FP32]

TYPE_BOOL = [DType.BOOL]
# floating-types and INT32
TYPE_FI32 = TYPE_FP + [DType.INT32]
# floating-types, integers (Excludes INT4) and BOOL
TYPE_FIB = [DType.FP16, DType.BF16, DType.FP32] + TYPE_INT + TYPE_BOOL
TYPE_FI16 = [DType.FP32, DType.INT16]

# INT8/INT16 plus the floating-point dtypes
TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range applied by initOpListDefaults to ops that do not set 'rank'
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002599
2600 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002601 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002602 "argmax": {
2603 "op": Op.ARGMAX,
2604 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002605 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002606 "build_fcn": (
2607 build_argmax,
2608 TosaTensorGen.tgBasic,
2609 TosaTensorValuesGen.tvgDefault,
2610 TosaArgGen.agAxis,
2611 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002612 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002613 "error_if_validators": (
2614 TosaErrorValidator.evAxisSmallerZero,
2615 TosaErrorValidator.evAxisLargerRank,
2616 TosaErrorValidator.evArgmaxOutputRankMismatch,
2617 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2618 TosaErrorValidator.evWrongRank,
2619 TosaErrorValidator.evWrongInputType,
2620 TosaErrorValidator.evWrongOutputType,
2621 TosaErrorValidator.evWrongInputList,
2622 TosaErrorValidator.evWrongOutputList,
2623 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002624 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002625 "avg_pool2d": {
2626 "op": Op.AVG_POOL2D,
2627 "operands": (1, 0),
2628 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002629 "build_fcn": (
2630 build_pool2d,
2631 TosaTensorGen.tgNHWC,
2632 TosaTensorValuesGen.tvgDefault,
2633 TosaArgGen.agPooling,
2634 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002635 "qgen": TosaQuantGen.qgUnary,
2636 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002637 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002638 "error_if_validators": (
2639 TosaErrorValidator.evKernelSmallerOne,
2640 TosaErrorValidator.evStrideSmallerOne,
2641 TosaErrorValidator.evPadSmallerZero,
2642 TosaErrorValidator.evWrongRank,
2643 TosaErrorValidator.evWrongInputType,
2644 TosaErrorValidator.evWrongOutputType,
2645 TosaErrorValidator.evWrongInputList,
2646 TosaErrorValidator.evWrongOutputList,
2647 TosaErrorValidator.evInputZeroPointNotZero,
2648 TosaErrorValidator.evOutputZeroPointNotZero,
2649 TosaErrorValidator.evPadLargerEqualKernel,
2650 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002651 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002652 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002653 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002654 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002655 "conv2d_TEMPLATE": {
2656 "op": Op.CONV2D,
2657 "operands": (1, 2),
2658 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002659 "build_fcn": (
2660 build_conv2d,
2661 TosaTensorGen.tgConv2D,
2662 TosaTensorValuesGen.tvgDefault,
2663 TosaArgGen.agConv,
2664 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002665 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002666 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002667 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2668 "error_if_validators": (
2669 TosaErrorValidator.evWrongInputType,
2670 TosaErrorValidator.evWrongOutputType,
2671 TosaErrorValidator.evWrongInputList,
2672 TosaErrorValidator.evWrongOutputList,
2673 TosaErrorValidator.evInputZeroPointNotZero,
2674 TosaErrorValidator.evWeightZeroPointNotZero,
2675 TosaErrorValidator.evPadSmallerZero,
2676 TosaErrorValidator.evStrideSmallerOne,
2677 TosaErrorValidator.evDilationSmallerOne,
2678 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002679 TosaErrorValidator.evConvOutputShapeMismatch,
2680 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002681 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002682 "template": True,
2683 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002684 # Templated operator. Filled in by createDynamicOpLists
2685 "conv3d_TEMPLATE": {
2686 "op": Op.CONV3D,
2687 "operands": (1, 2),
2688 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002689 "build_fcn": (
2690 build_conv3d,
2691 TosaTensorGen.tgConv3D,
2692 TosaTensorValuesGen.tvgDefault,
2693 TosaArgGen.agConv,
2694 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002695 "qgen": TosaQuantGen.qgConv,
2696 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002697 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2698 "error_if_validators": (
2699 TosaErrorValidator.evWrongInputType,
2700 TosaErrorValidator.evWrongOutputType,
2701 TosaErrorValidator.evWrongInputList,
2702 TosaErrorValidator.evWrongOutputList,
2703 TosaErrorValidator.evInputZeroPointNotZero,
2704 TosaErrorValidator.evWeightZeroPointNotZero,
2705 TosaErrorValidator.evPadSmallerZero,
2706 TosaErrorValidator.evStrideSmallerOne,
2707 TosaErrorValidator.evDilationSmallerOne,
2708 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002709 TosaErrorValidator.evConvOutputShapeMismatch,
2710 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002711 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002712 "template": True,
2713 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002714 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002715 "depthwise_conv2d_TEMPLATE": {
2716 "op": Op.DEPTHWISE_CONV2D,
2717 "operands": (1, 2),
2718 "filter": [1, 1],
2719 "rank": (4, 4),
2720 "build_fcn": (
2721 build_depthwise_conv2d,
2722 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002723 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002724 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002725 ),
2726 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002727 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002728 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2729 "error_if_validators": (
2730 TosaErrorValidator.evWrongInputType,
2731 TosaErrorValidator.evWrongOutputType,
2732 TosaErrorValidator.evWrongInputList,
2733 TosaErrorValidator.evWrongOutputList,
2734 TosaErrorValidator.evInputZeroPointNotZero,
2735 TosaErrorValidator.evWeightZeroPointNotZero,
2736 TosaErrorValidator.evPadSmallerZero,
2737 TosaErrorValidator.evStrideSmallerOne,
2738 TosaErrorValidator.evDilationSmallerOne,
2739 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002740 TosaErrorValidator.evConvOutputShapeMismatch,
2741 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002742 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002743 "template": True,
2744 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002745 "fully_connected": {
2746 "op": Op.FULLY_CONNECTED,
2747 "operands": (1, 2),
2748 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002749 "build_fcn": (
2750 build_fully_connected,
2751 TosaTensorGen.tgFullyConnected,
2752 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002753 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002754 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002755 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002756 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002757 "error_if_validators": (
2758 TosaErrorValidator.evInputZeroPointNotZero,
2759 TosaErrorValidator.evWeightZeroPointNotZero,
2760 TosaErrorValidator.evWrongRank,
2761 TosaErrorValidator.evWrongInputType,
2762 TosaErrorValidator.evWrongOutputType,
2763 TosaErrorValidator.evWrongInputList,
2764 TosaErrorValidator.evWrongOutputList,
2765 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002766 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002767 "matmul": {
2768 "op": Op.MATMUL,
2769 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002770 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002771 "build_fcn": (
2772 build_matmul,
2773 TosaTensorGen.tgMatmul,
2774 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002775 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002776 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002777 "qgen": TosaQuantGen.qgMatmul,
2778 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002779 "error_if_validators": (
2780 TosaErrorValidator.evInputZeroPointNotZero,
2781 TosaErrorValidator.evWrongRank,
2782 TosaErrorValidator.evWrongInputType,
2783 TosaErrorValidator.evWrongOutputType,
2784 TosaErrorValidator.evWrongInputList,
2785 TosaErrorValidator.evWrongOutputList,
2786 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002787 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002788 "max_pool2d": {
2789 "op": Op.MAX_POOL2D,
2790 "operands": (1, 0),
2791 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002792 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01002793 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002794 TosaTensorGen.tgNHWC,
2795 TosaTensorValuesGen.tvgDefault,
2796 TosaArgGen.agPooling,
2797 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002798 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002799 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002800 "error_if_validators": (
2801 TosaErrorValidator.evKernelSmallerOne,
2802 TosaErrorValidator.evStrideSmallerOne,
2803 TosaErrorValidator.evPadSmallerZero,
2804 TosaErrorValidator.evWrongRank,
2805 TosaErrorValidator.evWrongInputType,
2806 TosaErrorValidator.evWrongOutputType,
2807 TosaErrorValidator.evWrongInputList,
2808 TosaErrorValidator.evWrongOutputList,
2809 TosaErrorValidator.evPadLargerEqualKernel,
2810 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002811 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002812 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002813 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002814 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002815 "transpose_conv2d_TEMPLATE": {
2816 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002817 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002818 "rank": (4, 4),
2819 "build_fcn": (
2820 build_transpose_conv2d,
2821 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002822 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002823 TosaArgGen.agTransposeConv2D,
2824 ),
2825 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002826 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002827 "invalid_test_validators": (
2828 TosaInvalidValidator.ivHeightWidthInvalid,
2829 TosaInvalidValidator.ivNonPositiveOutputShape,
2830 ),
2831 "error_if_validators": (
2832 TosaErrorValidator.evWrongInputType,
2833 TosaErrorValidator.evWrongOutputType,
2834 TosaErrorValidator.evWrongInputList,
2835 TosaErrorValidator.evWrongOutputList,
2836 TosaErrorValidator.evInputZeroPointNotZero,
2837 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07002838 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00002839 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002840 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002841 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002842 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002843 "template": True,
2844 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002845 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002846 "clamp": {
2847 "op": Op.CLAMP,
2848 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002849 "build_fcn": (
2850 build_clamp,
2851 TosaTensorGen.tgBasic,
2852 TosaTensorValuesGen.tvgDefault,
2853 None,
2854 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002855 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002856 "error_if_validators": (
2857 TosaErrorValidator.evMaxSmallerMin,
2858 TosaErrorValidator.evWrongInputType,
2859 TosaErrorValidator.evWrongOutputType,
2860 TosaErrorValidator.evWrongInputList,
2861 TosaErrorValidator.evWrongOutputList,
2862 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002863 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002864 "sigmoid": {
2865 "op": Op.SIGMOID,
2866 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002867 "build_fcn": (
2868 build_sigmoid,
2869 TosaTensorGen.tgBasic,
2870 TosaTensorValuesGen.tvgDefault,
2871 None,
2872 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002873 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002874 "error_if_validators": (
2875 TosaErrorValidator.evWrongInputType,
2876 TosaErrorValidator.evWrongOutputType,
2877 TosaErrorValidator.evWrongInputList,
2878 TosaErrorValidator.evWrongOutputList,
2879 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002880 },
2881 "tanh": {
2882 "op": Op.TANH,
2883 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002884 "build_fcn": (
2885 build_tanh,
2886 TosaTensorGen.tgBasic,
2887 TosaTensorValuesGen.tvgDefault,
2888 None,
2889 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002890 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002891 "error_if_validators": (
2892 TosaErrorValidator.evWrongInputType,
2893 TosaErrorValidator.evWrongOutputType,
2894 TosaErrorValidator.evWrongInputList,
2895 TosaErrorValidator.evWrongOutputList,
2896 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002897 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002898 # Elementwise Binary Operators
2899 "add": {
2900 "op": Op.ADD,
2901 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002902 "build_fcn": (
2903 build_binary_broadcast,
2904 TosaTensorGen.tgBroadcastFuzz,
2905 TosaTensorValuesGen.tvgAddSub,
2906 None,
2907 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002908 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002909 "error_if_validators": (
2910 TosaErrorValidator.evRankMismatch,
2911 TosaErrorValidator.evWrongInputType,
2912 TosaErrorValidator.evWrongOutputType,
2913 TosaErrorValidator.evWrongInputList,
2914 TosaErrorValidator.evWrongOutputList,
2915 TosaErrorValidator.evDimensionMismatch,
2916 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002917 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002918 "arithmetic_right_shift": {
2919 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2920 "operands": (2, 0),
2921 "build_fcn": (
2922 build_arithmetic_right_shift,
2923 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002924 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002925 TosaArgGen.agArithmeticRightShift,
2926 ),
2927 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002928 "error_if_validators": (
2929 TosaErrorValidator.evRankMismatch,
2930 TosaErrorValidator.evWrongInputType,
2931 TosaErrorValidator.evWrongOutputType,
2932 TosaErrorValidator.evWrongInputList,
2933 TosaErrorValidator.evWrongOutputList,
2934 TosaErrorValidator.evDimensionMismatch,
2935 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002936 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002937 "bitwise_and": {
2938 "op": Op.BITWISE_AND,
2939 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002940 "build_fcn": (
2941 build_binary_broadcast,
2942 TosaTensorGen.tgBroadcastFuzz,
2943 TosaTensorValuesGen.tvgDefault,
2944 None,
2945 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002946 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002947 "error_if_validators": (
2948 TosaErrorValidator.evRankMismatch,
2949 TosaErrorValidator.evWrongInputType,
2950 TosaErrorValidator.evWrongOutputType,
2951 TosaErrorValidator.evWrongInputList,
2952 TosaErrorValidator.evWrongOutputList,
2953 TosaErrorValidator.evDimensionMismatch,
2954 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002955 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002956 "bitwise_or": {
2957 "op": Op.BITWISE_OR,
2958 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002959 "build_fcn": (
2960 build_binary_broadcast,
2961 TosaTensorGen.tgBroadcastFuzz,
2962 TosaTensorValuesGen.tvgDefault,
2963 None,
2964 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002965 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002966 "error_if_validators": (
2967 TosaErrorValidator.evRankMismatch,
2968 TosaErrorValidator.evWrongInputType,
2969 TosaErrorValidator.evWrongOutputType,
2970 TosaErrorValidator.evWrongInputList,
2971 TosaErrorValidator.evWrongOutputList,
2972 TosaErrorValidator.evDimensionMismatch,
2973 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002974 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002975 "bitwise_xor": {
2976 "op": Op.BITWISE_XOR,
2977 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002978 "build_fcn": (
2979 build_binary_broadcast,
2980 TosaTensorGen.tgBroadcastFuzz,
2981 TosaTensorValuesGen.tvgDefault,
2982 None,
2983 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002984 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002985 "error_if_validators": (
2986 TosaErrorValidator.evRankMismatch,
2987 TosaErrorValidator.evWrongInputType,
2988 TosaErrorValidator.evWrongOutputType,
2989 TosaErrorValidator.evWrongInputList,
2990 TosaErrorValidator.evWrongOutputList,
2991 TosaErrorValidator.evDimensionMismatch,
2992 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002993 },
Matthew Haddon459443c2021-08-23 16:43:13 +01002994 "intdiv": {
2995 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002996 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002997 "build_fcn": (
2998 build_binary_broadcast,
2999 TosaTensorGen.tgBroadcastFuzz,
3000 TosaTensorValuesGen.tvgIntDiv,
3001 None,
3002 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003003 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003004 "error_if_validators": (
3005 TosaErrorValidator.evRankMismatch,
3006 TosaErrorValidator.evWrongInputType,
3007 TosaErrorValidator.evWrongOutputType,
3008 TosaErrorValidator.evWrongInputList,
3009 TosaErrorValidator.evWrongOutputList,
3010 TosaErrorValidator.evDimensionMismatch,
3011 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003012 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003013 "logical_and": {
3014 "op": Op.LOGICAL_AND,
3015 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003016 "build_fcn": (
3017 build_binary_broadcast,
3018 TosaTensorGen.tgBroadcastFuzz,
3019 TosaTensorValuesGen.tvgDefault,
3020 None,
3021 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003022 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003023 "error_if_validators": (
3024 TosaErrorValidator.evRankMismatch,
3025 TosaErrorValidator.evWrongInputType,
3026 TosaErrorValidator.evWrongOutputType,
3027 TosaErrorValidator.evWrongInputList,
3028 TosaErrorValidator.evWrongOutputList,
3029 TosaErrorValidator.evDimensionMismatch,
3030 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003031 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003032 "logical_left_shift": {
3033 "op": Op.LOGICAL_LEFT_SHIFT,
3034 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003035 "build_fcn": (
3036 build_binary_broadcast,
3037 TosaTensorGen.tgBroadcastFuzz,
3038 TosaTensorValuesGen.tvgLogicalShift,
3039 None,
3040 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003041 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003042 "error_if_validators": (
3043 TosaErrorValidator.evRankMismatch,
3044 TosaErrorValidator.evWrongInputType,
3045 TosaErrorValidator.evWrongOutputType,
3046 TosaErrorValidator.evWrongInputList,
3047 TosaErrorValidator.evWrongOutputList,
3048 TosaErrorValidator.evDimensionMismatch,
3049 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003050 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003051 "logical_right_shift": {
3052 "op": Op.LOGICAL_RIGHT_SHIFT,
3053 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003054 "build_fcn": (
3055 build_binary_broadcast,
3056 TosaTensorGen.tgBroadcastFuzz,
3057 TosaTensorValuesGen.tvgLogicalShift,
3058 None,
3059 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003060 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003061 "error_if_validators": (
3062 TosaErrorValidator.evRankMismatch,
3063 TosaErrorValidator.evWrongInputType,
3064 TosaErrorValidator.evWrongOutputType,
3065 TosaErrorValidator.evWrongInputList,
3066 TosaErrorValidator.evWrongOutputList,
3067 TosaErrorValidator.evDimensionMismatch,
3068 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003069 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003070 "logical_or": {
3071 "op": Op.LOGICAL_OR,
3072 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003073 "build_fcn": (
3074 build_binary_broadcast,
3075 TosaTensorGen.tgBroadcastFuzz,
3076 TosaTensorValuesGen.tvgDefault,
3077 None,
3078 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003079 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003080 "error_if_validators": (
3081 TosaErrorValidator.evRankMismatch,
3082 TosaErrorValidator.evWrongInputType,
3083 TosaErrorValidator.evWrongOutputType,
3084 TosaErrorValidator.evWrongInputList,
3085 TosaErrorValidator.evWrongOutputList,
3086 TosaErrorValidator.evDimensionMismatch,
3087 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003088 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003089 "logical_xor": {
3090 "op": Op.LOGICAL_XOR,
3091 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003092 "build_fcn": (
3093 build_binary_broadcast,
3094 TosaTensorGen.tgBroadcastFuzz,
3095 TosaTensorValuesGen.tvgDefault,
3096 None,
3097 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003098 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003099 "error_if_validators": (
3100 TosaErrorValidator.evRankMismatch,
3101 TosaErrorValidator.evWrongInputType,
3102 TosaErrorValidator.evWrongOutputType,
3103 TosaErrorValidator.evWrongInputList,
3104 TosaErrorValidator.evWrongOutputList,
3105 TosaErrorValidator.evDimensionMismatch,
3106 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003107 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003108 "maximum": {
3109 "op": Op.MAXIMUM,
3110 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003111 "build_fcn": (
3112 build_binary_broadcast,
3113 TosaTensorGen.tgBroadcastFuzz,
3114 TosaTensorValuesGen.tvgDefault,
3115 None,
3116 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003117 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003118 "error_if_validators": (
3119 TosaErrorValidator.evRankMismatch,
3120 TosaErrorValidator.evWrongInputType,
3121 TosaErrorValidator.evWrongOutputType,
3122 TosaErrorValidator.evWrongInputList,
3123 TosaErrorValidator.evWrongOutputList,
3124 TosaErrorValidator.evDimensionMismatch,
3125 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003126 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003127 "minimum": {
3128 "op": Op.MINIMUM,
3129 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003130 "build_fcn": (
3131 build_binary_broadcast,
3132 TosaTensorGen.tgBroadcastFuzz,
3133 TosaTensorValuesGen.tvgDefault,
3134 None,
3135 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003136 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003137 "error_if_validators": (
3138 TosaErrorValidator.evRankMismatch,
3139 TosaErrorValidator.evWrongInputType,
3140 TosaErrorValidator.evWrongOutputType,
3141 TosaErrorValidator.evWrongInputList,
3142 TosaErrorValidator.evWrongOutputList,
3143 TosaErrorValidator.evDimensionMismatch,
3144 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003145 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003146 "mul": {
3147 "op": Op.MUL,
3148 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003149 "build_fcn": (
3150 build_mul,
3151 TosaTensorGen.tgBroadcastFuzz,
3152 TosaTensorValuesGen.tvgMul,
3153 TosaArgGen.agMul,
3154 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003155 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003156 "error_if_validators": (
3157 TosaErrorValidator.evWrongInputType,
3158 TosaErrorValidator.evWrongOutputType,
3159 TosaErrorValidator.evWrongInputList,
3160 TosaErrorValidator.evWrongOutputList,
3161 TosaErrorValidator.evRankMismatch,
3162 TosaErrorValidator.evDimensionMismatch,
3163 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003164 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003165 "pow": {
3166 "op": Op.POW,
3167 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003168 "build_fcn": (
3169 build_binary_broadcast,
3170 TosaTensorGen.tgBroadcastFuzz,
3171 TosaTensorValuesGen.tvgDefault,
3172 None,
3173 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003174 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003175 "error_if_validators": (
3176 TosaErrorValidator.evRankMismatch,
3177 TosaErrorValidator.evWrongInputType,
3178 TosaErrorValidator.evWrongOutputType,
3179 TosaErrorValidator.evWrongInputList,
3180 TosaErrorValidator.evWrongOutputList,
3181 TosaErrorValidator.evDimensionMismatch,
3182 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003183 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003184 "sub": {
3185 "op": Op.SUB,
3186 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003187 "build_fcn": (
3188 build_binary_broadcast,
3189 TosaTensorGen.tgBroadcastFuzz,
3190 TosaTensorValuesGen.tvgAddSub,
3191 None,
3192 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003193 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003194 "error_if_validators": (
3195 TosaErrorValidator.evRankMismatch,
3196 TosaErrorValidator.evWrongInputType,
3197 TosaErrorValidator.evWrongOutputType,
3198 TosaErrorValidator.evWrongInputList,
3199 TosaErrorValidator.evWrongOutputList,
3200 TosaErrorValidator.evDimensionMismatch,
3201 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003202 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003203 "table": {
3204 "op": Op.TABLE,
3205 # Use the automatic generation functions to create the input array
3206 # but create the table tensor in the build function, as it may be
3207 # a different type from the input
3208 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003209 "build_fcn": (
3210 build_table,
3211 TosaTensorGen.tgBasic,
3212 TosaTensorValuesGen.tvgDefault,
3213 TosaArgGen.agTable,
3214 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003215 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003216 "error_if_validators": (
3217 TosaErrorValidator.evWrongInputType,
3218 TosaErrorValidator.evWrongOutputType,
3219 TosaErrorValidator.evWrongInputList,
3220 TosaErrorValidator.evWrongOutputList,
3221 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003222 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003223 # Elementwise Unary operators
3224 "abs": {
3225 "op": Op.ABS,
3226 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003227 "build_fcn": (
3228 build_unary,
3229 TosaTensorGen.tgBasic,
3230 TosaTensorValuesGen.tvgDefault,
3231 None,
3232 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003233 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003234 "error_if_validators": (
3235 TosaErrorValidator.evWrongInputType,
3236 TosaErrorValidator.evWrongOutputType,
3237 TosaErrorValidator.evWrongInputList,
3238 TosaErrorValidator.evWrongOutputList,
3239 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003240 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003241 "bitwise_not": {
3242 "op": Op.BITWISE_NOT,
3243 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003244 "build_fcn": (
3245 build_unary,
3246 TosaTensorGen.tgBasic,
3247 TosaTensorValuesGen.tvgDefault,
3248 None,
3249 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003250 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003251 "error_if_validators": (
3252 TosaErrorValidator.evWrongInputType,
3253 TosaErrorValidator.evWrongOutputType,
3254 TosaErrorValidator.evWrongInputList,
3255 TosaErrorValidator.evWrongOutputList,
3256 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003257 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003258 "ceil": {
3259 "op": Op.CEIL,
3260 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003261 "build_fcn": (
3262 build_unary,
3263 TosaTensorGen.tgBasic,
3264 TosaTensorValuesGen.tvgDefault,
3265 None,
3266 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003267 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003268 "error_if_validators": (
3269 TosaErrorValidator.evWrongInputType,
3270 TosaErrorValidator.evWrongOutputType,
3271 TosaErrorValidator.evWrongInputList,
3272 TosaErrorValidator.evWrongOutputList,
3273 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003274 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003275 "clz": {
3276 "op": Op.CLZ,
3277 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003278 "build_fcn": (
3279 build_unary,
3280 TosaTensorGen.tgBasic,
3281 TosaTensorValuesGen.tvgDefault,
3282 None,
3283 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003284 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003285 "error_if_validators": (
3286 TosaErrorValidator.evWrongInputType,
3287 TosaErrorValidator.evWrongOutputType,
3288 TosaErrorValidator.evWrongInputList,
3289 TosaErrorValidator.evWrongOutputList,
3290 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003291 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003292 "exp": {
3293 "op": Op.EXP,
3294 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003295 "build_fcn": (
3296 build_unary,
3297 TosaTensorGen.tgBasic,
3298 TosaTensorValuesGen.tvgDefault,
3299 None,
3300 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003301 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003302 "error_if_validators": (
3303 TosaErrorValidator.evWrongInputType,
3304 TosaErrorValidator.evWrongOutputType,
3305 TosaErrorValidator.evWrongInputList,
3306 TosaErrorValidator.evWrongOutputList,
3307 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003308 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003309 "floor": {
3310 "op": Op.FLOOR,
3311 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003312 "build_fcn": (
3313 build_unary,
3314 TosaTensorGen.tgBasic,
3315 TosaTensorValuesGen.tvgDefault,
3316 None,
3317 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003318 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003319 "error_if_validators": (
3320 TosaErrorValidator.evWrongInputType,
3321 TosaErrorValidator.evWrongOutputType,
3322 TosaErrorValidator.evWrongInputList,
3323 TosaErrorValidator.evWrongOutputList,
3324 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003325 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003326 "log": {
3327 "op": Op.LOG,
3328 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003329 "build_fcn": (
3330 build_unary,
3331 TosaTensorGen.tgBasic,
3332 TosaTensorValuesGen.tvgDefault,
3333 None,
3334 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003335 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003336 "error_if_validators": (
3337 TosaErrorValidator.evWrongInputType,
3338 TosaErrorValidator.evWrongOutputType,
3339 TosaErrorValidator.evWrongInputList,
3340 TosaErrorValidator.evWrongOutputList,
3341 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003342 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003343 "logical_not": {
3344 "op": Op.LOGICAL_NOT,
3345 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003346 "build_fcn": (
3347 build_unary,
3348 TosaTensorGen.tgBasic,
3349 TosaTensorValuesGen.tvgDefault,
3350 None,
3351 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003352 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003353 "error_if_validators": (
3354 TosaErrorValidator.evWrongInputType,
3355 TosaErrorValidator.evWrongOutputType,
3356 TosaErrorValidator.evWrongInputList,
3357 TosaErrorValidator.evWrongOutputList,
3358 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003359 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003360 "negate": {
3361 "op": Op.NEGATE,
3362 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003363 "build_fcn": (
3364 build_unary,
3365 TosaTensorGen.tgBasic,
3366 TosaTensorValuesGen.tvgNegate,
3367 None,
3368 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003369 "qgen": TosaQuantGen.qgUnary,
3370 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003371 "error_if_validators": (
3372 TosaErrorValidator.evInputZeroPointNotZero,
3373 TosaErrorValidator.evOutputZeroPointNotZero,
3374 TosaErrorValidator.evWrongInputType,
3375 TosaErrorValidator.evWrongOutputType,
3376 TosaErrorValidator.evWrongInputList,
3377 TosaErrorValidator.evWrongOutputList,
3378 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003379 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003380 "reciprocal": {
3381 "op": Op.RECIPROCAL,
3382 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003383 "build_fcn": (
3384 build_unary,
3385 TosaTensorGen.tgBasic,
3386 TosaTensorValuesGen.tvgDefault,
3387 None,
3388 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003389 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003390 "error_if_validators": (
3391 TosaErrorValidator.evWrongInputType,
3392 TosaErrorValidator.evWrongOutputType,
3393 TosaErrorValidator.evWrongInputList,
3394 TosaErrorValidator.evWrongOutputList,
3395 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003396 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003397 "rsqrt": {
3398 "op": Op.RSQRT,
3399 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003400 "build_fcn": (
3401 build_unary,
3402 TosaTensorGen.tgBasic,
3403 TosaTensorValuesGen.tvgDefault,
3404 None,
3405 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003406 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003407 "error_if_validators": (
3408 TosaErrorValidator.evWrongInputType,
3409 TosaErrorValidator.evWrongOutputType,
3410 TosaErrorValidator.evWrongInputList,
3411 TosaErrorValidator.evWrongOutputList,
3412 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003413 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003414 # Elementwise Ternary operators
3415 "select": {
3416 "op": Op.SELECT,
3417 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003418 "build_fcn": (
3419 build_select,
3420 TosaTensorGen.tgBroadcastFuzz,
3421 TosaTensorValuesGen.tvgSelect,
3422 None,
3423 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003424 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003425 "error_if_validators": (
3426 TosaErrorValidator.evRankMismatch,
3427 TosaErrorValidator.evWrongInputType,
3428 TosaErrorValidator.evWrongOutputType,
3429 TosaErrorValidator.evWrongInputList,
3430 TosaErrorValidator.evWrongOutputList,
3431 TosaErrorValidator.evDimensionMismatch,
3432 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003433 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003434 # Comparison operators
3435 "equal": {
3436 "op": Op.EQUAL,
3437 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003438 "build_fcn": (
3439 build_comparison,
3440 TosaTensorGen.tgBroadcastFuzz,
3441 TosaTensorValuesGen.tvgEqual,
3442 None,
3443 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003444 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003445 "error_if_validators": (
3446 TosaErrorValidator.evRankMismatch,
3447 TosaErrorValidator.evWrongInputType,
3448 TosaErrorValidator.evWrongOutputType,
3449 TosaErrorValidator.evWrongInputList,
3450 TosaErrorValidator.evWrongOutputList,
3451 TosaErrorValidator.evDimensionMismatch,
3452 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003453 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003454 "greater_equal": {
3455 "op": Op.GREATER_EQUAL,
3456 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003457 "build_fcn": (
3458 build_comparison,
3459 TosaTensorGen.tgBroadcastFuzz,
3460 TosaTensorValuesGen.tvgDefault,
3461 None,
3462 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003463 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003464 "error_if_validators": (
3465 TosaErrorValidator.evRankMismatch,
3466 TosaErrorValidator.evWrongInputType,
3467 TosaErrorValidator.evWrongOutputType,
3468 TosaErrorValidator.evWrongInputList,
3469 TosaErrorValidator.evWrongOutputList,
3470 TosaErrorValidator.evDimensionMismatch,
3471 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003472 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003473 "greater": {
3474 "op": Op.GREATER,
3475 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003476 "build_fcn": (
3477 build_comparison,
3478 TosaTensorGen.tgBroadcastFuzz,
3479 TosaTensorValuesGen.tvgDefault,
3480 None,
3481 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003482 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003483 "error_if_validators": (
3484 TosaErrorValidator.evRankMismatch,
3485 TosaErrorValidator.evWrongInputType,
3486 TosaErrorValidator.evWrongOutputType,
3487 TosaErrorValidator.evWrongInputList,
3488 TosaErrorValidator.evWrongOutputList,
3489 TosaErrorValidator.evDimensionMismatch,
3490 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003491 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003492 # Reduction operators
3493 "reduce_all": {
3494 "op": Op.REDUCE_ALL,
3495 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003496 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003497 "build_fcn": (
3498 build_reduce,
3499 TosaTensorGen.tgBasic,
3500 TosaTensorValuesGen.tvgDefault,
3501 TosaArgGen.agAxis,
3502 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003503 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003504 "error_if_validators": (
3505 TosaErrorValidator.evAxisLargerRank,
3506 TosaErrorValidator.evAxisSmallerZero,
3507 TosaErrorValidator.evShapeOfAxisNotOne,
3508 TosaErrorValidator.evWrongInputType,
3509 TosaErrorValidator.evWrongOutputType,
3510 TosaErrorValidator.evWrongRank,
3511 TosaErrorValidator.evWrongInputList,
3512 TosaErrorValidator.evWrongOutputList,
3513 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003514 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003515 "reduce_any": {
3516 "op": Op.REDUCE_ANY,
3517 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003518 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003519 "build_fcn": (
3520 build_reduce,
3521 TosaTensorGen.tgBasic,
3522 TosaTensorValuesGen.tvgDefault,
3523 TosaArgGen.agAxis,
3524 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003525 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003526 "error_if_validators": (
3527 TosaErrorValidator.evAxisLargerRank,
3528 TosaErrorValidator.evAxisSmallerZero,
3529 TosaErrorValidator.evShapeOfAxisNotOne,
3530 TosaErrorValidator.evWrongInputType,
3531 TosaErrorValidator.evWrongOutputType,
3532 TosaErrorValidator.evWrongRank,
3533 TosaErrorValidator.evWrongInputList,
3534 TosaErrorValidator.evWrongOutputList,
3535 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003536 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003537 "reduce_max": {
3538 "op": Op.REDUCE_MAX,
3539 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003540 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003541 "build_fcn": (
3542 build_reduce,
3543 TosaTensorGen.tgBasic,
3544 TosaTensorValuesGen.tvgDefault,
3545 TosaArgGen.agAxis,
3546 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003547 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003548 "error_if_validators": (
3549 TosaErrorValidator.evAxisLargerRank,
3550 TosaErrorValidator.evAxisSmallerZero,
3551 TosaErrorValidator.evShapeOfAxisNotOne,
3552 TosaErrorValidator.evWrongInputType,
3553 TosaErrorValidator.evWrongOutputType,
3554 TosaErrorValidator.evWrongRank,
3555 TosaErrorValidator.evWrongInputList,
3556 TosaErrorValidator.evWrongOutputList,
3557 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003558 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003559 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003560 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003561 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003562 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003563 "build_fcn": (
3564 build_reduce,
3565 TosaTensorGen.tgBasic,
3566 TosaTensorValuesGen.tvgDefault,
3567 TosaArgGen.agAxis,
3568 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003569 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003570 "error_if_validators": (
3571 TosaErrorValidator.evAxisLargerRank,
3572 TosaErrorValidator.evAxisSmallerZero,
3573 TosaErrorValidator.evShapeOfAxisNotOne,
3574 TosaErrorValidator.evWrongInputType,
3575 TosaErrorValidator.evWrongOutputType,
3576 TosaErrorValidator.evWrongRank,
3577 TosaErrorValidator.evWrongInputList,
3578 TosaErrorValidator.evWrongOutputList,
3579 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003580 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003581 "reduce_product": {
3582 "op": Op.REDUCE_PRODUCT,
3583 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003584 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003585 "build_fcn": (
3586 build_reduce,
3587 TosaTensorGen.tgBasic,
3588 TosaTensorValuesGen.tvgDefault,
3589 TosaArgGen.agAxis,
3590 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003591 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003592 "error_if_validators": (
3593 TosaErrorValidator.evAxisLargerRank,
3594 TosaErrorValidator.evAxisSmallerZero,
3595 TosaErrorValidator.evShapeOfAxisNotOne,
3596 TosaErrorValidator.evWrongInputType,
3597 TosaErrorValidator.evWrongOutputType,
3598 TosaErrorValidator.evWrongRank,
3599 TosaErrorValidator.evWrongInputList,
3600 TosaErrorValidator.evWrongOutputList,
3601 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003602 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003603 "reduce_sum": {
3604 "op": Op.REDUCE_SUM,
3605 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003606 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003607 "build_fcn": (
3608 build_reduce,
3609 TosaTensorGen.tgBasic,
3610 TosaTensorValuesGen.tvgReduceSum,
3611 TosaArgGen.agAxis,
3612 ),
James Ward24dbc422022-10-19 12:20:31 +01003613 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003614 "error_if_validators": (
3615 TosaErrorValidator.evAxisLargerRank,
3616 TosaErrorValidator.evAxisSmallerZero,
3617 TosaErrorValidator.evShapeOfAxisNotOne,
3618 TosaErrorValidator.evWrongInputType,
3619 TosaErrorValidator.evWrongOutputType,
3620 TosaErrorValidator.evWrongRank,
3621 TosaErrorValidator.evWrongInputList,
3622 TosaErrorValidator.evWrongOutputList,
3623 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003624 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003625 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003626 "concat": {
3627 "op": Op.CONCAT,
3628 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003629 "build_fcn": (
3630 build_concat,
3631 TosaTensorGen.tgConcat,
3632 TosaTensorValuesGen.tvgConcat,
3633 TosaArgGen.agAxis,
3634 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003635 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003636 "error_if_validators": (
3637 TosaErrorValidator.evAxisLargerRank,
3638 TosaErrorValidator.evAxisSmallerZero,
3639 TosaErrorValidator.evConcatInputRankMismatch,
3640 TosaErrorValidator.evConcatShapeSumMismatch,
3641 TosaErrorValidator.evConcatInputDimMismatch,
3642 TosaErrorValidator.evWrongInputType,
3643 TosaErrorValidator.evWrongOutputType,
3644 TosaErrorValidator.evWrongOutputList,
3645 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003646 },
3647 "pad": {
3648 "op": Op.PAD,
3649 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003650 "rank": (1, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003651 "build_fcn": (
3652 build_pad,
3653 TosaTensorGen.tgBasic,
3654 TosaTensorValuesGen.tvgDefault,
3655 TosaArgGen.agPad,
3656 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003657 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003658 "error_if_validators": (
3659 TosaErrorValidator.evWrongInputType,
3660 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003661 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003662 TosaErrorValidator.evWrongOutputType,
3663 TosaErrorValidator.evWrongInputList,
3664 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003665 TosaErrorValidator.evRankMismatch,
3666 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003667 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003668 },
3669 "reshape": {
3670 "op": Op.RESHAPE,
3671 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003672 "build_fcn": (
3673 build_reshape,
3674 TosaTensorGen.tgBasic,
3675 TosaTensorValuesGen.tvgDefault,
3676 TosaArgGen.agReshape,
3677 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003678 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003679 "error_if_validators": (
3680 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3681 TosaErrorValidator.evWrongInputType,
3682 TosaErrorValidator.evWrongOutputType,
3683 TosaErrorValidator.evWrongInputList,
3684 TosaErrorValidator.evWrongOutputList,
3685 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003686 },
3687 "reverse": {
3688 "op": Op.REVERSE,
3689 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003690 "build_fcn": (
3691 build_reverse,
3692 TosaTensorGen.tgBasic,
3693 TosaTensorValuesGen.tvgDefault,
3694 TosaArgGen.agAxis,
3695 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003696 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003697 "error_if_validators": (
3698 TosaErrorValidator.evAxisSmallerZero,
3699 TosaErrorValidator.evAxisLargerRank,
3700 TosaErrorValidator.evWrongInputType,
3701 TosaErrorValidator.evWrongOutputType,
3702 TosaErrorValidator.evWrongInputList,
3703 TosaErrorValidator.evWrongOutputList,
3704 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003705 },
3706 "slice": {
3707 "op": Op.SLICE,
3708 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003709 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003710 "build_fcn": (
3711 build_slice,
3712 TosaTensorGen.tgBasic,
3713 TosaTensorValuesGen.tvgDefault,
3714 TosaArgGen.agSlice,
3715 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003716 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003717 "error_if_validators": (
3718 TosaErrorValidator.evStartSmallerZero,
3719 TosaErrorValidator.evSizeSmallerEqualZero,
3720 TosaErrorValidator.evStartSizeOutsideBounds,
3721 TosaErrorValidator.evSizeOutputShapeMismatch,
3722 TosaErrorValidator.evInputSizeStartLengthMismatch,
3723 TosaErrorValidator.evWrongRank,
3724 TosaErrorValidator.evWrongInputType,
3725 TosaErrorValidator.evWrongOutputType,
3726 TosaErrorValidator.evWrongInputList,
3727 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003728 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003729 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003730 },
3731 "tile": {
3732 "op": Op.TILE,
3733 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003734 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003735 "build_fcn": (
3736 build_tile,
3737 TosaTensorGen.tgBasic,
3738 TosaTensorValuesGen.tvgDefault,
3739 TosaArgGen.agTile,
3740 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003741 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003742 "error_if_validators": (
3743 TosaErrorValidator.evWrongInputType,
3744 TosaErrorValidator.evWrongOutputType,
3745 TosaErrorValidator.evWrongInputList,
3746 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003747 TosaErrorValidator.evRankMismatch,
3748 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003749 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003750 },
3751 "transpose": {
3752 "op": Op.TRANSPOSE,
3753 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003754 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003755 "build_fcn": (
3756 build_transpose,
3757 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003758 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003759 TosaArgGen.agTranspose,
3760 ),
3761 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003762 "error_if_validators": (
3763 TosaErrorValidator.evIndexOutsideBounds,
3764 TosaErrorValidator.evIndexUsedTwice,
3765 TosaErrorValidator.evWrongInputType,
3766 TosaErrorValidator.evWrongOutputType,
3767 TosaErrorValidator.evWrongInputList,
3768 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003769 TosaErrorValidator.evWrongRank,
3770 TosaErrorValidator.evRankMismatch,
3771 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003772 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003773 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003774 # Data nodes
3775 "const": {
3776 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003777 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003778 "build_fcn": (
3779 build_const,
3780 TosaTensorGen.tgBasic,
3781 TosaTensorValuesGen.tvgDefault,
3782 None,
3783 ),
Luke Hutton65872422023-02-20 10:33:04 +00003784 "types": TYPE_FIB + [DType.INT48],
3785 "error_if_validators": (
3786 TosaErrorValidator.evWrongInputType,
3787 TosaErrorValidator.evWrongOutputType,
3788 TosaErrorValidator.evWrongInputList,
3789 TosaErrorValidator.evWrongOutputList,
3790 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003791 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003792 "identity": {
3793 "op": Op.IDENTITY,
3794 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003795 "build_fcn": (
3796 build_unary,
3797 TosaTensorGen.tgBasic,
3798 TosaTensorValuesGen.tvgDefault,
3799 None,
3800 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003801 "types": TYPE_FIB,
3802 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003803 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003804 "gather": {
3805 "op": Op.GATHER,
3806 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3807 "operands": (1, 0),
3808 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003809 "build_fcn": (
3810 build_gather,
3811 TosaTensorGen.tgBasic,
3812 TosaTensorValuesGen.tvgDefault,
3813 None,
3814 ),
James Ward24dbc422022-10-19 12:20:31 +01003815 "types": (
3816 DType.INT8,
3817 DType.INT16,
3818 DType.INT32,
3819 DType.FP16,
3820 DType.BF16,
3821 DType.FP32,
3822 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003823 "error_if_validators": (
3824 TosaErrorValidator.evWrongInputType,
3825 TosaErrorValidator.evWrongOutputType,
3826 TosaErrorValidator.evWrongInputList,
3827 TosaErrorValidator.evWrongOutputList,
3828 TosaErrorValidator.evWrongRank,
3829 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003830 },
3831 "scatter": {
3832 "op": Op.SCATTER,
3833 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003834 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003835 "operands": (2, 0),
3836 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003837 "build_fcn": (
3838 build_scatter,
3839 TosaTensorGen.tgScatter,
3840 TosaTensorValuesGen.tvgDefault,
3841 None,
3842 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003843 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003844 "error_if_validators": (
3845 TosaErrorValidator.evWrongInputType,
3846 TosaErrorValidator.evWrongOutputType,
3847 TosaErrorValidator.evWrongInputList,
3848 TosaErrorValidator.evWrongOutputList,
3849 TosaErrorValidator.evWrongRank,
3850 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003851 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003852 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003853 "resize": {
3854 "op": Op.RESIZE,
3855 "operands": (1, 0),
3856 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003857 "build_fcn": (
3858 build_resize,
3859 TosaTensorGen.tgNHWC,
3860 TosaTensorValuesGen.tvgDefault,
3861 TosaArgGen.agResize,
3862 ),
James Ward24dbc422022-10-19 12:20:31 +01003863 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003864 "invalid_test_validators": (
3865 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003866 ),
3867 "error_if_validators": (
3868 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003869 TosaErrorValidator.evScaleSmallerEqualZero,
3870 TosaErrorValidator.evScaleNLargerMax,
3871 TosaErrorValidator.evScaleDLargerMax,
3872 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003873 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003874 TosaErrorValidator.evBorderSmallerMin,
3875 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003876 TosaErrorValidator.evWrongInputType,
3877 TosaErrorValidator.evWrongOutputType,
3878 TosaErrorValidator.evWrongRank,
3879 TosaErrorValidator.evWrongInputList,
3880 TosaErrorValidator.evWrongOutputList,
3881 TosaErrorValidator.evBatchMismatch,
3882 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003883 TosaErrorValidator.evResizeOutputShapeMismatch,
3884 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003885 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003886 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003887 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003888 "cast": {
3889 "op": Op.CAST,
3890 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003891 "build_fcn": (
3892 build_cast,
3893 TosaTensorGen.tgBasic,
3894 TosaTensorValuesGen.tvgDefault,
3895 TosaArgGen.agCast,
3896 ),
James Ward8b390432022-08-12 20:48:56 +01003897 "types": (
3898 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01003899 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003900 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01003901 DType.INT8,
3902 DType.INT16,
3903 DType.INT32,
3904 DType.BOOL,
3905 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003906 "error_if_validators": (
3907 TosaErrorValidator.evWrongInputType,
3908 TosaErrorValidator.evWrongOutputType,
3909 TosaErrorValidator.evWrongInputList,
3910 TosaErrorValidator.evWrongOutputList,
3911 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003912 },
3913 "rescale": {
3914 "op": Op.RESCALE,
3915 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003916 "build_fcn": (
3917 build_rescale,
3918 TosaTensorGen.tgBasic,
3919 TosaTensorValuesGen.tvgDefault,
3920 TosaArgGen.agRescale,
3921 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003922 "types": [
3923 DType.UINT8,
3924 DType.INT8,
3925 DType.INT16,
3926 DType.INT32,
3927 DType.INT48,
3928 DType.UINT16,
3929 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003930 "error_if_validators": (
3931 TosaErrorValidator.evInputZeroPointNotZero,
3932 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003933 TosaErrorValidator.evU16InputZeroPointNotValid,
3934 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003935 TosaErrorValidator.evScaleTrue,
3936 TosaErrorValidator.evScaleNotTrue,
3937 TosaErrorValidator.evWrongInputType,
3938 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003939 TosaErrorValidator.evWrongInputList,
3940 TosaErrorValidator.evWrongOutputList,
3941 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003942 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003943 # Custom
3944 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003945 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07003946 # Two variants of cond_if, one that generates one of two constant tensors (no
3947 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
3948 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003949 "cond_if_const": {
3950 "op": Op.COND_IF,
3951 "operands": (0, 2),
3952 "build_fcn": (
3953 build_cond_if_const,
3954 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003955 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003956 TosaArgGen.agCondIf,
3957 ),
3958 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003959 "error_if_validators": (
3960 TosaErrorValidator.evOutputListThenGraphMismatch,
3961 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003962 TosaErrorValidator.evCondIfCondNotMatchingBool,
3963 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003964 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003965 },
3966 "cond_if_binary": {
3967 "op": Op.COND_IF,
3968 "operands": (2, 0),
3969 "build_fcn": (
3970 build_cond_if_binary,
3971 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003972 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003973 TosaArgGen.agCondIf,
3974 ),
Les Bell6040b4d2021-10-11 12:50:31 +01003975 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003976 "error_if_validators": (
3977 TosaErrorValidator.evInputListThenGraphMismatch,
3978 TosaErrorValidator.evInputListElseGraphMismatch,
3979 TosaErrorValidator.evOutputListThenGraphMismatch,
3980 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003981 TosaErrorValidator.evCondIfCondNotMatchingBool,
3982 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003983 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003984 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003985 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003986 "while_loop": {
3987 "op": Op.WHILE_LOOP,
3988 "operands": (0, 1),
3989 "build_fcn": (
3990 build_while_loop,
3991 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003992 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003993 TosaArgGen.agWhileLoop,
3994 ),
3995 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003996 "error_if_validators": (
3997 TosaErrorValidator.evInputListOutputListMismatch,
3998 TosaErrorValidator.evInputListCondGraphMismatch,
3999 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4000 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4001 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004002 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004003 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004004 },
Luke Hutton57287132023-02-06 14:54:18 +00004005 "fft2d": {
4006 "op": Op.FFT2D,
4007 "operands": (2, 0),
4008 "rank": (3, 3),
4009 "build_fcn": (
4010 build_fft2d,
4011 TosaTensorGen.tgFFT2d,
4012 TosaTensorValuesGen.tvgDefault,
4013 TosaArgGen.agFFT2d,
4014 ),
4015 "types": [DType.FP32],
4016 "error_if_validators": (
4017 TosaErrorValidator.evWrongInputType,
4018 TosaErrorValidator.evWrongOutputType,
4019 TosaErrorValidator.evWrongInputList,
4020 TosaErrorValidator.evWrongOutputList,
4021 TosaErrorValidator.evWrongRank,
4022 TosaErrorValidator.evBatchMismatch,
4023 TosaErrorValidator.evKernelNotPowerOfTwo,
4024 TosaErrorValidator.evFFTInputShapeMismatch,
4025 TosaErrorValidator.evFFTOutputShapeMismatch,
4026 ),
4027 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004028 "rfft2d": {
4029 "op": Op.RFFT2D,
4030 "operands": (1, 0),
4031 "rank": (3, 3),
4032 "build_fcn": (
4033 build_rfft2d,
4034 TosaTensorGen.tgRFFT2d,
4035 TosaTensorValuesGen.tvgDefault,
4036 TosaArgGen.agNone,
4037 ),
4038 "types": [DType.FP32],
4039 "error_if_validators": (
4040 TosaErrorValidator.evWrongInputType,
4041 TosaErrorValidator.evWrongOutputType,
4042 TosaErrorValidator.evWrongInputList,
4043 TosaErrorValidator.evWrongOutputList,
4044 TosaErrorValidator.evWrongRank,
4045 TosaErrorValidator.evBatchMismatch,
4046 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004047 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004048 ),
4049 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004050 }
4051
Kevin Cheng550ccc52021-03-03 11:21:43 -08004052
Eric Kunzee5e26762020-10-13 16:11:07 -07004053class OutputShaper:
4054 # Methods in this class compute the expected output shape and datatype
4055 # for common classes of operations
def __init__(self):
    """Construct an OutputShaper; there is no per-instance state to set up."""
4058
4059 # These methods return arguments that can be used for
4060 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an elementwise binary operator with broadcasting.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: Random generator; ``rng.choice`` picks a wrong output dtype
            when ``error_name`` is ``ErrorIf.WrongOutputType``.
        a: First input tensor (``shape`` and ``dtype`` are read).
        b: Second input tensor (``shape`` and ``dtype`` are read).
        error_name: ``ErrorIf`` case being generated, or ``None``.

    Returns:
        The value returned by ``ser.addOutput`` for the new output.
    """
    # Rank equality is only required when not deliberately testing a mismatch.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # A size-1 dimension of `a` takes b's size only when no ERROR_IF case is
    # being generated; for error cases b's rank is not guaranteed to match.
    may_broadcast = error_name is None
    shape = [
        b.shape[idx] if (dim == 1 and may_broadcast) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        # Any dtype other than the input's makes the output type invalid.
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        out_dtype = rng.choice(list(set(all_dtypes) - {a.dtype}))

    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004090
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for a binary operator whose inputs must match exactly.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the output.
        a: First input tensor; its shape and dtype become the output's.
        b: Second input tensor; must have the same rank, dims and dtype as ``a``.

    Returns:
        The value returned by ``ser.addOutput`` for the new output.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for idx, dim in enumerate(a.shape):
        # No broadcasting: every dimension must agree.
        assert dim == b.shape[idx]
        out_shape.append(dim)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004102
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for an elementwise unary operator.

    The output has the input's shape. Its dtype is the input's, unless
    ``error_name`` is ``ErrorIf.WrongOutputType``. In that case a random
    different dtype is chosen with ``rng.choice``.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: Random generator providing ``choice``.
        a: Input tensor (``shape`` and ``dtype`` are read).
        error_name: ``ErrorIf`` case being generated, or ``None``.

    Returns:
        The value returned by ``ser.addOutput`` for the new output.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(all_dtypes) - {a.dtype}))

    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004121
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for SELECT(cond, a, b).

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: Random generator; ``rng.choice`` picks a wrong output dtype
            when ``error_name`` is ``ErrorIf.WrongOutputType``.
        cond: Condition tensor; its rank drives the output shape.
        a: First value tensor.
        b: Second value tensor; must share ``a``'s dtype.
        error_name: ``ErrorIf`` case being generated, or ``None``.

    Returns:
        The value returned by ``ser.addOutput`` for the new output.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    shape = []
    for idx, cond_dim in enumerate(cond.shape):
        if error_name is None and cond_dim == 1:
            # Valid tests: a size-1 cond dim broadcasts to the largest operand dim.
            shape.append(max(cond_dim, a.shape[idx], b.shape[idx]))
        else:
            shape.append(cond_dim)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(all_dtypes) - {a.dtype}))

    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004151
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the BOOL output tensor for an elementwise comparison operator.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: Random generator; ``rng.choice`` picks a non-BOOL dtype when
            ``error_name`` is ``ErrorIf.WrongOutputType``.
        a: First input tensor; its rank determines the output rank.
        b: Second input tensor; must share ``a``'s dtype.
        error_name: ``ErrorIf`` case being generated, or ``None``.

    Returns:
        The value returned by ``ser.addOutput`` for the new output.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast size-1 dims of `a`, but only where `b` actually has that
    # dimension (b may have a lower rank in rank-mismatch tests).
    shape = [
        b.shape[idx] if dim == 1 and idx < len(b.shape) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name == ErrorIf.WrongOutputType:
        # List order is significant: rng.choice indexes into it directly.
        out_dtype = rng.choice(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )
    else:
        out_dtype = DType.BOOL

    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004181
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a REDUCE_* operator along ``axis``.

    For valid tests the reduced axis becomes size 1. For axis-related
    ERROR_IF cases the shape is left unreduced. For
    ``ErrorIf.ShapeOfAxisNotOne`` the axis is forced to a size in [2, 10).

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: Random generator providing ``integers`` and ``choice``.
        a: Input tensor (``shape`` list and ``dtype`` are read).
        axis: Index of the reduced dimension.
        error_name: ``ErrorIf`` case being generated, or ``None``.

    Returns:
        The value returned by ``ser.addOutput`` for the new output.
    """
    shape = a.shape.copy()

    axis_error_cases = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in axis_error_cases:
        shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
        # Make sure the reduced axis is visibly NOT of size one.
        shape[axis] = rng.integers(2, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(all_dtypes) - {a.dtype}))

    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004210
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the INT32 output tensor for ARGMAX along ``axis``.

    For valid tests the reduced axis is removed from the shape. Several
    ERROR_IF cases distort the shape or dtype on purpose.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: Random generator providing ``choice`` and ``integers``.
        a: Input tensor (its ``shape`` list is read).
        axis: Index of the dimension being reduced.
        error_name: ``ErrorIf`` case being generated, or ``None``.

    Returns:
        The value returned by ``ser.addOutput`` for the new output.
    """
    shape = a.shape.copy()

    # An out-of-range axis cannot be removed, so leave the shape untouched.
    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        shape.pop(axis)

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # Randomly drop the leading dim (if more than one remains) or append one.
        if rng.choice([True, False]) and len(shape) > 1:
            shape.pop(0)
        else:
            shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        shape = [dim + rng.integers(1, 10) for dim in shape]

    out_dtype = DType.INT32
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(all_dtypes) - {DType.INT32}))

    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004244
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for CONV2D.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: Random generator providing ``choice``.
        ifm: Input feature map tensor.
        filter: Weight tensor; dim 0 gives the output channels, dims 1/2 the
            kernel height/width.
        accum_dtype: Dtype used for the output in the normal case.
        strides: ``[stride_y, stride_x]``.
        padding: Indices 0/1 pad the height, indices 2/3 pad the width.
        dilations: ``[dilation_y, dilation_x]``.
        error_name: ``ErrorIf`` case being generated, or ``None``.

    Returns:
        The value returned by ``ser.addOutput`` for the new output.
    """
    in_h = ifm.shape[1]
    in_w = ifm.shape[2]
    k_h = filter.shape[1]
    k_w = filter.shape[2]

    h = (in_h - 1 + padding[0] + padding[1] - (k_h - 1) * dilations[0]) // strides[
        0
    ] + 1
    w = (in_w - 1 + padding[2] + padding[3] - (k_w - 1) * dilations[1]) // strides[
        1
    ] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow by whole multiples of the stride so the shape stays "integer"
        # and only the mismatch error is triggered.
        if change in (1, 3):
            h += rng.choice(choices) * strides[0]
        if change in (2, 3):
            w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    # With a bad input type, fall back to some plausible output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs may legitimately accumulate in FP16 or FP32, so exclude both.
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004296
4297 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004298 def conv3dOp(
4299 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
4300 ):
Kevin Cheng1533b852021-09-01 12:51:58 -07004301
4302 # IFM: NDHWC
4303 # Filter: ODHWI
4304 # OFM: NDHWC
4305
4306 d = (
4307 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004308 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07004309 + padding[0]
4310 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004311 - (filter.shape[1] - 1) * dilations[0]
Kevin Cheng1533b852021-09-01 12:51:58 -07004312 ) // strides[0] + 1
4313
4314 h = (
4315 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004316 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07004317 + padding[2]
4318 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004319 - (filter.shape[2] - 1) * dilations[1]
Kevin Cheng1533b852021-09-01 12:51:58 -07004320 ) // strides[1] + 1
4321
4322 w = (
4323 ifm.shape[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004324 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07004325 + padding[4]
4326 + padding[5]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004327 - (filter.shape[3] - 1) * dilations[2]
Kevin Cheng1533b852021-09-01 12:51:58 -07004328 ) // strides[2] + 1
4329
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004330 if error_name == ErrorIf.ConvOutputShapeMismatch:
4331 choices = [1, 2, 3, 4]
4332 change = rng.choice(choices)
4333 # increment in multiples of stride to not hit non-integer error case
4334 if change in [1, 4]:
4335 d = d + (rng.choice(choices) * strides[0])
4336 if change in [2, 4]:
4337 h = h + (rng.choice(choices) * strides[1])
4338 if change in [3, 4]:
4339 w = w + (rng.choice(choices) * strides[2])
Les Bell0e027d42021-11-09 14:42:14 +00004340
Kevin Cheng1533b852021-09-01 12:51:58 -07004341 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
4342
James Ward8b390432022-08-12 20:48:56 +01004343 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004344 # Pick some potentially correct output dtype if input type is incorrect
4345 out_dtype = DType.INT32
Kevin Cheng1533b852021-09-01 12:51:58 -07004346 else:
James Ward8b390432022-08-12 20:48:56 +01004347 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004348
4349 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004350 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004351 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004352 else:
4353 excludes = [out_dtype]
4354 wrong_dtypes = list(usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004355 out_dtype = rng.choice(wrong_dtypes)
Kevin Cheng1533b852021-09-01 12:51:58 -07004356
4357 return ser.addOutput(ofm_shape, out_dtype)
4358
4359 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004360 def depthwiseConv2dOp(
James Ward8b390432022-08-12 20:48:56 +01004361 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004362 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07004363 # IFM: NHWC
4364 # Filter: HWCM
4365 # OFM: NHW C*M
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004366
Kevin Cheng550ccc52021-03-03 11:21:43 -08004367 h = (
4368 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004369 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004370 + padding[0]
4371 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004372 - (filter.shape[0] - 1) * dilations[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004373 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004374
Kevin Cheng550ccc52021-03-03 11:21:43 -08004375 w = (
4376 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004377 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004378 + padding[2]
4379 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004380 - (filter.shape[1] - 1) * dilations[1]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004381 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004382
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004383 if error_name == ErrorIf.ConvOutputShapeMismatch:
4384 choices = [1, 2, 3]
4385 change = rng.choice(choices)
4386 # increment in multiples of stride to not hit non-integer error case
4387 if change in [1, 3]:
4388 h = h + (rng.choice(choices) * strides[0])
4389 if change in [2, 3]:
4390 w = w + (rng.choice(choices) * strides[1])
Les Bell0e027d42021-11-09 14:42:14 +00004391
Eric Kunzee5e26762020-10-13 16:11:07 -07004392 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
4393
James Ward8b390432022-08-12 20:48:56 +01004394 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004395 # Pick some potentially correct output dtype if input type is incorrect
4396 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004397 else:
James Ward8b390432022-08-12 20:48:56 +01004398 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004399
4400 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004401 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004402 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004403 else:
4404 excludes = [out_dtype]
4405 wrong_dtypes = list(usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004406 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07004407
Kevin Cheng550ccc52021-03-03 11:21:43 -08004408 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004409
4410 @staticmethod
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004411 def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004412 # input: NHWC
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004413 if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004414 # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004415 h = 1
4416 w = 1
4417 else:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004418 h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
4419 w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004420
4421 if error_name == ErrorIf.PoolingOutputShapeMismatch:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004422 choices = [1, 2, 3]
4423 change = rng.choice(choices)
4424 # increment in multiples of stride to not hit non-integer error case
4425 if change in [1, 3]:
4426 h = h + (rng.choice(choices) * stride[0])
4427 if change in [2, 3]:
4428 w = w + (rng.choice(choices) * stride[1])
Eric Kunzee5e26762020-10-13 16:11:07 -07004429 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004430
4431 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004432 all_dtypes = [
4433 DType.INT8,
4434 DType.INT16,
4435 DType.INT32,
4436 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004437 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004438 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004439 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004440 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004441 wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
4442 outputDType = rng.choice(wrong_dtypes)
4443 else:
4444 outputDType = ifm.dtype
4445
4446 return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004447
4448 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004449 def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004450 # input: N, IC
4451 # filter: OC, IC
4452 # output: N, OC
4453
4454 output_shape = [input.shape[0], filter.shape[0]]
4455
James Ward8b390432022-08-12 20:48:56 +01004456 # Validated in arg_gen (also invalidated for ErrorIf)
4457 out_dtype = accum_dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004458
Kevin Cheng550ccc52021-03-03 11:21:43 -08004459 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004460
4461 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004462 def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
Kevin Cheng2d60f002021-06-09 14:18:32 -07004463 # a: N, H, C
4464 # b: N, C, W
4465 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07004466
Kevin Cheng2d60f002021-06-09 14:18:32 -07004467 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004468
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004469 if error_name == ErrorIf.WrongOutputType:
4470 if a.dtype == DType.INT8:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004471 incorrect_types = (
4472 DType.INT4,
4473 DType.INT8,
4474 DType.INT16,
4475 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004476 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004477 DType.FP16,
4478 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004479 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004480 elif a.dtype == DType.INT16:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004481 incorrect_types = (
4482 DType.INT4,
4483 DType.INT8,
4484 DType.INT16,
4485 DType.INT32,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004486 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004487 DType.FP16,
4488 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004489 )
James Ward24dbc422022-10-19 12:20:31 +01004490 elif (
4491 a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
4492 ):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004493 incorrect_types = (
4494 DType.INT4,
4495 DType.INT8,
4496 DType.INT16,
4497 DType.INT32,
4498 DType.INT48,
4499 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004500 out_dtype = rng.choice(a=incorrect_types)
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004501 elif error_name == ErrorIf.WrongInputType:
4502 # Pick some potentially correct output dtype if input type is incorrect
4503 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004504 else:
James Ward8b390432022-08-12 20:48:56 +01004505 out_dtype = accum_dtype # Validated in arg_gen
Eric Kunzee5e26762020-10-13 16:11:07 -07004506
Kevin Cheng550ccc52021-03-03 11:21:43 -08004507 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004508
4509 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004510 def concatOp(ser, rng, axis, *a, error_name=None):
Matthew Haddon818ab902021-07-27 09:12:49 +01004511 input1 = a[0]
4512 remaining_inputs = a[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07004513
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004514 # calculate the output shape, if possible, otherwise just use the first input shape
Matthew Haddon818ab902021-07-27 09:12:49 +01004515 output_shape = input1.shape.copy()
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004516 if not (
4517 # unable to concat tensors of different ranks
4518 error_name == ErrorIf.ConcatInputRankMismatch
4519 # unable to concat tensors along an invalid axis
4520 or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004521 ):
4522 for tensor in remaining_inputs:
4523 output_shape[axis] += tensor.shape[axis]
Eric Kunzee5e26762020-10-13 16:11:07 -07004524
Matthew Haddon01c359d2021-10-15 16:30:48 +01004525 if error_name == ErrorIf.ConcatShapeSumMismatch:
4526 output_shape[axis] += rng.integers(5, 10)
4527
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004528 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004529 all_dtypes = {
4530 DType.INT8,
4531 DType.INT16,
4532 DType.INT32,
4533 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004534 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004535 DType.FP16,
4536 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004537 }
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004538 wrong_dtypes = list(all_dtypes - set([input1.dtype]))
4539 outputDType = rng.choice(wrong_dtypes)
4540 else:
4541 outputDType = input1.dtype
Matthew Haddon818ab902021-07-27 09:12:49 +01004542
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004543 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004544
4545 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01004546 def padOp(ser, rng, a, padding, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004547
4548 output_shape = a.shape.copy()
4549
4550 for i in range(len(output_shape)):
4551 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
4552
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004553 if error_name == ErrorIf.PadOutputShapeMismatch:
4554 bad_dim = rng.choice(range(len(output_shape)))
4555 output_shape[bad_dim] -= rng.choice([1, 2])
Luke Huttona4e48ca2023-02-22 11:53:48 +00004556 elif error_name == ErrorIf.RankMismatch:
4557 output_shape = get_rank_mismatch_shape(rng, output_shape)
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004558
Matthew Haddone807aae2021-10-11 18:12:58 +01004559 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004560 all_dtypes = [
4561 DType.INT8,
4562 DType.INT16,
4563 DType.INT32,
4564 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004565 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004566 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004567 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004568 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01004569 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4570 outputDType = rng.choice(wrong_dtypes)
4571 else:
4572 outputDType = a.dtype
4573
4574 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004575
4576 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01004577 def reshapeOp(ser, rng, a, shape, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004578 output_shape = shape.copy()
4579
Matthew Haddone807aae2021-10-11 18:12:58 +01004580 if error_name == ErrorIf.TensorSizeInputOutputMismatch:
4581 for i in range(len(output_shape)):
4582 output_shape[i] = output_shape[i] + rng.integers(1, 10)
4583
4584 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004585 all_dtypes = [
4586 DType.INT8,
4587 DType.INT16,
4588 DType.INT32,
4589 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004590 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004591 DType.FP16,
4592 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004593 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01004594 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4595 outputDType = rng.choice(wrong_dtypes)
4596 else:
4597 outputDType = a.dtype
4598
4599 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004600
4601 @staticmethod
Luke Huttona4e48ca2023-02-22 11:53:48 +00004602 def sliceOp(ser, rng, input, start, size, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004603
Matthew Haddone807aae2021-10-11 18:12:58 +01004604 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004605 all_dtypes = [
4606 DType.INT8,
4607 DType.INT16,
4608 DType.INT32,
4609 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004610 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004611 DType.FP16,
4612 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004613 ]
Luke Huttona4e48ca2023-02-22 11:53:48 +00004614 wrong_dtypes = list(set(all_dtypes) - set([input.dtype]))
Matthew Haddone807aae2021-10-11 18:12:58 +01004615 outputDType = rng.choice(wrong_dtypes)
4616 else:
Luke Huttona4e48ca2023-02-22 11:53:48 +00004617 outputDType = input.dtype
Matthew Haddone807aae2021-10-11 18:12:58 +01004618
Luke Huttona4e48ca2023-02-22 11:53:48 +00004619 output_shape = size.copy()
Matthew Haddone807aae2021-10-11 18:12:58 +01004620 if error_name == ErrorIf.SizeOutputShapeMismatch:
Matthew Haddone807aae2021-10-11 18:12:58 +01004621 for index in range(len(output_shape)):
4622 if output_shape[index] <= 2:
4623 output_shape[index] = output_shape[index] + rng.choice([1, 2])
4624 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004625 output_shape[index] = output_shape[index] + rng.choice(
4626 [-2, -1, 1, 2]
4627 )
Luke Huttona4e48ca2023-02-22 11:53:48 +00004628 elif error_name == ErrorIf.InputSizeStartLengthMismatch:
4629 output_shape = input.shape.copy()
4630 elif error_name == ErrorIf.RankMismatch:
4631 output_shape = get_rank_mismatch_shape(rng, output_shape)
Matthew Haddone807aae2021-10-11 18:12:58 +01004632
4633 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004634
4635 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004636 def tileOp(ser, rng, a, multiples, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004637
4638 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08004639 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07004640
4641 for i in range(len(output_shape)):
4642 output_shape[i] = a.shape[i] * multiples[i]
4643
Luke Huttona4e48ca2023-02-22 11:53:48 +00004644 if error_name == ErrorIf.RankMismatch:
4645 output_shape = get_rank_mismatch_shape(rng, output_shape)
4646
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004647 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004648 all_dtypes = [
4649 DType.INT8,
4650 DType.INT16,
4651 DType.INT32,
4652 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004653 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004654 DType.FP16,
4655 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004656 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004657 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4658 outputDType = rng.choice(wrong_dtypes)
4659 else:
4660 outputDType = a.dtype
4661
4662 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004663
4664 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01004665 def transposeOp(ser, rng, a, perms, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004666 output_shape = a.shape.copy()
Matthew Haddone807aae2021-10-11 18:12:58 +01004667
Kevin Cheng550ccc52021-03-03 11:21:43 -08004668 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07004669
Luke Huttona4e48ca2023-02-22 11:53:48 +00004670 if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]:
Matthew Haddone807aae2021-10-11 18:12:58 +01004671 for i in range(len(output_shape)):
4672 output_shape[i] = a.shape[perms[i]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004673
Luke Huttona4e48ca2023-02-22 11:53:48 +00004674 if error_name == ErrorIf.TensorSizeInputOutputMismatch:
4675 for i in range(len(output_shape)):
4676 output_shape[i] += rng.integers(1, 10)
4677 elif error_name == ErrorIf.RankMismatch:
4678 output_shape = get_rank_mismatch_shape(rng, output_shape)
4679
Matthew Haddone807aae2021-10-11 18:12:58 +01004680 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004681 all_dtypes = [
4682 DType.INT8,
4683 DType.INT16,
4684 DType.INT32,
4685 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004686 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004687 DType.FP16,
4688 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004689 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01004690 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4691 outputDType = rng.choice(wrong_dtypes)
4692 else:
4693 outputDType = a.dtype
4694
4695 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004696
4697 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004698 def gatherOp(ser, rng, values, indices, error_name=None):
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004699 if error_name != ErrorIf.WrongRank:
4700 assert len(values.shape) == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -08004701 assert len(indices.shape) == 2
4702 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07004703
Kevin Cheng77d0f762020-11-24 10:26:32 -08004704 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
4705
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004706 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004707 all_dtypes = [
4708 DType.INT8,
4709 DType.INT16,
4710 DType.INT32,
4711 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004712 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004713 DType.FP16,
4714 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004715 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004716 wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
4717 outputDType = rng.choice(wrong_dtypes)
4718 else:
4719 outputDType = values.dtype
4720
4721 return ser.addOutput(output_shape, outputDType)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004722
4723 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004724 def scatterOp(ser, rng, values_in, indices, input, error_name=None):
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004725 if error_name != ErrorIf.WrongRank:
4726 assert len(values_in.shape) == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -08004727 assert len(indices.shape) == 2
4728 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08004729 assert values_in.shape[0] == indices.shape[0] # N
4730 assert input.shape[1] == indices.shape[1] # W
4731 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08004732
4733 output_shape = values_in.shape
4734
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004735 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004736 all_dtypes = [
4737 DType.INT8,
4738 DType.INT16,
4739 DType.INT32,
4740 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004741 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004742 DType.FP16,
4743 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004744 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004745 wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
4746 outputDType = rng.choice(wrong_dtypes)
4747 else:
4748 outputDType = values_in.dtype
4749
4750 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004751
4752 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004753 def tableOp(ser, rng, input, error_name=None):
4754 # Same shape as the input, dtype dependent on input dtype
4755 if error_name != ErrorIf.WrongInputType:
4756 assert input.dtype == DType.INT16 or input.dtype == DType.INT8
Kevin Chengfe392ce2021-10-18 21:51:55 +00004757 output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004758 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004759 wrong_dtypes = [
4760 DType.INT8,
4761 DType.INT16,
4762 DType.INT32,
4763 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004764 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004765 DType.FP16,
4766 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004767 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004768 wrong_dtypes.remove(output_dtype)
4769 output_dtype = rng.choice(wrong_dtypes)
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004770 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004771
4772 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08004773 def resizeOp(
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004774 serializer,
4775 rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004776 input,
4777 mode,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004778 scale,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004779 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004780 border,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004781 input_dtype,
4782 output_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004783 error_name=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004784 ):
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004785 # Calculate OH, OW
4786 scale_y_n = scale[0]
4787 scale_y_d = scale[1]
4788 scale_x_n = scale[2]
4789 scale_x_d = scale[3]
4790 if error_name == ErrorIf.ScaleSmallerEqualZero:
4791 scale_y_n = max(scale_y_n, 1)
4792 scale_y_d = max(scale_y_d, 1)
4793 scale_x_n = max(scale_x_n, 1)
4794 scale_x_d = max(scale_x_d, 1)
4795
4796 oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
4797 ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1
4798
4799 if error_name is not None:
4800 # Make sure the output tensor is valid, which can occur when
4801 # scale, offset or border have been changed for ERROR_IFs
4802 oh = max(oh, 1)
4803 ow = max(ow, 1)
4804 if error_name != ErrorIf.MaxDimExceeded:
4805 oh = min(oh, MAX_RESIZE_DIMENSION - 1)
4806 ow = min(ow, MAX_RESIZE_DIMENSION - 1)
4807
4808 if error_name == ErrorIf.ResizeOutputShapeMismatch:
4809 choices = [1, 2, 3]
4810 change = rng.choice(choices)
4811 # increment in multiples of scale_y/x_d so we don't hit non-integer error case
4812 if change in [1, 3]:
4813 if oh + scale_y_d >= MAX_RESIZE_DIMENSION:
4814 oh -= scale_y_d
4815 assert oh > 0 # Should have been caught in agResize
4816 else:
4817 oh += scale_y_d
4818 if change in [2, 3]:
4819 if ow + scale_x_d >= MAX_RESIZE_DIMENSION:
4820 ow -= scale_x_d
4821 assert ow > 0 # Should have been caught in agResize
4822 else:
4823 ow += scale_x_d
4824
Matthew Haddon848efb42021-09-09 12:30:53 +01004825 if error_name == ErrorIf.WrongRank:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004826 output_dims = [
4827 input.shape[0],
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004828 oh,
4829 ow,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004830 input.shape[0],
4831 ]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004832 elif error_name == ErrorIf.BatchMismatch:
4833 output_dims = [
4834 input.shape[0] + rng.integers(1, 10),
4835 oh,
4836 ow,
4837 input.shape[3],
4838 ]
4839 elif error_name == ErrorIf.ChannelMismatch:
4840 output_dims = [
4841 input.shape[0],
4842 oh,
4843 ow,
4844 input.shape[3] + rng.integers(1, 10),
4845 ]
Matthew Haddon848efb42021-09-09 12:30:53 +01004846 else:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004847 output_dims = [input.shape[0], oh, ow, input.shape[3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004848
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004849 return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004850
4851 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004852 def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004853 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004854
4855 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004856 def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004857 if error_name == ErrorIf.ConvOutputShapeMismatch:
4858 choices = [1, 2, 3]
4859 change = rng.choice(choices)
4860 if change in [1, 3]:
4861 output_shape[1] = output_shape[1] + rng.choice(choices)
4862 if change in [2, 3]:
4863 output_shape[2] = output_shape[2] + rng.choice(choices)
4864
James Ward8b390432022-08-12 20:48:56 +01004865 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004866 # Pick some potentially correct output dtype if input type is incorrect
4867 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004868 else:
James Ward8b390432022-08-12 20:48:56 +01004869 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004870
4871 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004872 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004873 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004874 else:
4875 excludes = [out_dtype]
4876 wrong_dtypes = list(usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004877 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07004878
Kevin Cheng550ccc52021-03-03 11:21:43 -08004879 return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00004880
4881 @staticmethod
Luke Hutton57287132023-02-06 14:54:18 +00004882 def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
4883 outputs = []
4884
4885 assert ifm1.dtype == ifm2.dtype
4886 input_dtype = ifm1.dtype
4887
4888 if error_name != ErrorIf.FFTInputShapeMismatch:
4889 assert ifm1.shape == ifm2.shape
4890
4891 input_shape = ifm1.shape
4892 if error_name != ErrorIf.WrongRank:
4893 assert len(input_shape) == 3
4894
4895 output_shape = input_shape.copy()
4896 output_dtype = input_dtype
4897
4898 if error_name == ErrorIf.WrongOutputType:
4899 excludes = [DType.FP32]
4900 wrong_dtypes = list(usableDTypes(excludes=excludes))
4901 output_dtype = rng.choice(wrong_dtypes)
4902 elif error_name == ErrorIf.BatchMismatch:
4903 output_shape[0] += rng.integers(1, 10)
4904 elif error_name == ErrorIf.FFTOutputShapeMismatch:
4905 modify_dim = rng.choice([1, 2])
4906 output_shape[modify_dim] += rng.integers(1, 10)
4907
4908 outputs.append(serializer.addOutput(output_shape, output_dtype))
4909 outputs.append(serializer.addOutput(output_shape, output_dtype))
4910 return outputs
4911
4912 @staticmethod
Luke Hutton261b7b62023-01-10 14:50:31 +00004913 def rfft2dOp(serializer, rng, value, error_name=None):
4914 outputs = []
4915
4916 input_shape = value.shape
4917 if error_name != ErrorIf.WrongRank:
4918 assert len(input_shape) == 3
4919
4920 output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1]
4921
4922 output_dtype = value.dtype
4923 if error_name == ErrorIf.WrongOutputType:
4924 excludes = [DType.FP32]
4925 wrong_dtypes = list(usableDTypes(excludes=excludes))
4926 output_dtype = rng.choice(wrong_dtypes)
4927 elif error_name == ErrorIf.BatchMismatch:
Luke Hutton57287132023-02-06 14:54:18 +00004928 output_shape[0] += rng.integers(1, 10)
4929 elif error_name == ErrorIf.FFTOutputShapeMismatch:
4930 modify_dim = rng.choice([1, 2])
4931 output_shape[modify_dim] += rng.integers(1, 10)
Luke Hutton261b7b62023-01-10 14:50:31 +00004932
4933 outputs.append(serializer.addOutput(output_shape, output_dtype))
4934 outputs.append(serializer.addOutput(output_shape, output_dtype))
4935 return outputs