blob: c816c6bb67fe5c413bc60c14acf26e41292356b2 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +010016from generator.tosa_utils import DTYPE_ATTRIBUTES
Luke Huttona4e48ca2023-02-22 11:53:48 +000017from generator.tosa_utils import get_rank_mismatch_shape
Jeremy Johnson05c711e2022-12-12 18:00:41 +000018from generator.tosa_utils import get_wrong_output_type
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010019from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010020from generator.tosa_utils import usableDTypes
James Ward24dbc422022-10-19 12:20:31 +010021from generator.tosa_utils import vect_f32_to_bf16
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
25
Eric Kunzee5e26762020-10-13 16:11:07 -070026class TosaTestGen:
    # Maximum rank of tensor supported by the test generator.
    # This currently matches the 8K level defined in the TOSA specification.
    TOSA_TENSOR_MAX_RANK = 6
30
Eric Kunzee5e26762020-10-13 16:11:07 -070031 def __init__(self, args):
32 self.args = args
33 self.basePath = args.output_dir
34 self.random_seed = args.random_seed
35 self.ser = None
36 self.rng = np.random.default_rng(self.random_seed)
37 self.createDynamicOpLists()
38 self.initOpListDefaults()
39 self.quantGen = TosaQuantGen()
40 # Force makeShape to do a specific starting shape
41 self.targetted_shape = None
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010042 # Work out floating point range
43 self.random_fp_low = min(args.tensor_fp_value_range)
44 self.random_fp_high = max(args.tensor_fp_value_range)
Eric Kunzee5e26762020-10-13 16:11:07 -070045
46 def createSerializer(self, opName, testPath):
47 self.testPath = os.path.join(opName, testPath)
48
49 fullPath = os.path.join(self.basePath, self.testPath)
50 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnsona0848c62022-09-15 15:01:30 +010051 self.ser = ts.TosaSerializer(fullPath, saveConstsToFile=self.args.dump_consts)
Eric Kunzee5e26762020-10-13 16:11:07 -070052
53 def getSerializer(self):
54 return self.ser
55
56 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -080057 with open(
58 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
59 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -070060 fd.write(self.ser.serialize())
61
Kevin Cheng550ccc52021-03-03 11:21:43 -080062 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
63 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -070064
Matthew Haddon74567092021-07-16 15:38:20 +010065 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000066 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +010067 seed = self.random_seed + 1
68 self.rng = np.random.default_rng(seed)
69
Eric Kunzee5e26762020-10-13 16:11:07 -070070 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -070071 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -070072 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -070073 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -070074 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -070075 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070076 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +010077 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
78 elif dtype == DType.UINT8:
79 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070080 elif dtype == DType.INT16:
81 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010082 elif dtype == DType.UINT16:
83 return np.int32(self.rng.integers(low=0, high=65536, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070084 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -080085 return np.int32(
86 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
87 )
Eric Kunzee5e26762020-10-13 16:11:07 -070088 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -080089 return np.int64(
90 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
91 )
James Ward8b390432022-08-12 20:48:56 +010092 elif dtype == DType.FP16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010093 return np.float16(
94 self.rng.uniform(
95 low=self.random_fp_low, high=self.random_fp_high, size=shape
96 )
97 )
James Ward24dbc422022-10-19 12:20:31 +010098 elif dtype == DType.BF16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010099 f32_tensor = np.float32(
100 self.rng.uniform(
101 low=self.random_fp_low, high=self.random_fp_high, size=shape
102 )
103 )
James Ward24dbc422022-10-19 12:20:31 +0100104 # Floor the last 16 bits of each f32 value
105 return np.float32(vect_f32_to_bf16(f32_tensor))
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100106 elif dtype == DType.FP32:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100107 return np.float32(
108 self.rng.uniform(
109 low=self.random_fp_low, high=self.random_fp_high, size=shape
110 )
111 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700112 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800113 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700114
Kevin Cheng989cb052021-04-28 16:29:44 -0700115 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700116 placeholders = []
117
Kevin Cheng989cb052021-04-28 16:29:44 -0700118 assert len(shape_list) == len(dtype_list)
119
120 for idx, shape in enumerate(shape_list):
121 arr = self.getRandTensor(shape, dtype_list[idx])
122 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700123
124 return placeholders
125
Kevin Cheng989cb052021-04-28 16:29:44 -0700126 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700127 consts = []
128
Kevin Cheng989cb052021-04-28 16:29:44 -0700129 assert len(shape_list) == len(dtype_list)
130
131 for idx, shape in enumerate(shape_list):
132 arr = self.getRandTensor(shape, dtype_list[idx])
133 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700134
135 return consts
136
137 def makeShape(self, rank):
138 if self.targetted_shape:
139 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800140 return np.int32(
141 self.rng.integers(
142 low=self.args.tensor_shape_range[0],
143 high=self.args.tensor_shape_range[1],
144 size=rank,
145 )
146 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700147
148 def setTargetShape(self, shape):
149 self.targetted_shape = shape
150
151 def randInt(self, low=0, high=256):
152 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
153
154 def getRandNumberDType(self, dtype):
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100155 if dtype == DType.FP32:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100156 return np.float32(
157 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
158 )
James Ward8b390432022-08-12 20:48:56 +0100159 elif dtype == DType.FP16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100160 return np.float16(
161 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
162 )
James Ward24dbc422022-10-19 12:20:31 +0100163 elif dtype == DType.BF16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100164 rand_f32 = np.float32(
165 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
166 )
James Ward24dbc422022-10-19 12:20:31 +0100167 return vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700168 elif dtype == DType.BOOL:
169 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -0700170 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -0700171 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -0700172 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700173 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100174 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -0700175 elif dtype == DType.INT16:
176 low, high = (-32768, 32768)
177 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800178 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -0700179 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800180 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -0700181 # Special size
182 return np.int64(self.rng.integers(low, high, size=1))[0]
183 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800184 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700185
186 return np.int32(self.rng.integers(low, high, size=1))[0]
187
188 def shapeStr(self, shape):
189
190 sStr = []
191 # Convert to strings
192 for i in shape:
193 sStr.append(str(i))
194
Kevin Cheng550ccc52021-03-03 11:21:43 -0800195 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700196
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100197 def typeStr(self, dtype):
198 if isinstance(dtype, list) or isinstance(dtype, tuple):
199 assert len(dtype) >= 2
200 strs = [self.typeStr(t) for t in dtype]
201 # Limit types to the first 2 as the 3rd is the accumulator
202 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700203 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100204 if dtype in DTYPE_ATTRIBUTES:
205 return DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700206 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100207 raise Exception(
208 "Unknown dtype, cannot convert to string: {}".format(dtype)
209 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700210
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100211 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100212 """Get the datatype width for data types"""
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100213 if dtype in DTYPE_ATTRIBUTES:
214 return DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700215 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100216 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700217
Luke Hutton57287132023-02-06 14:54:18 +0000218 def constrictBatchSize(self, shape):
219 # Limit the batch size unless an explicit target shape set
220 if self.args.max_batch_size and not self.args.target_shapes:
221 shape[0] = min(shape[0], self.args.max_batch_size)
222 return shape
223
James Ward30124a82023-02-02 14:56:33 +0000224 def makeDimension(self):
225 return self.randInt(
226 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
227 )
228
    # Operator build functions
    # Argument generators return a list of tuples
    # (stringDescriptor, [build_fcn_arg_list]): the string descriptor is used
    # to generate the test name, and build_fcn_arg_list is expanded and passed
    # to the matching build_* function below.
234
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100235 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
236 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
237
Matthew Haddon848efb42021-09-09 12:30:53 +0100238 # build_placeholder returns an int, ABS/other ops does not
239 if isinstance(op, int):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000240 self.ser.addOperator(op, a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100241 return result_tens
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000242 elif op["op"] == Op.IDENTITY:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000243 self.ser.addOperator(op["op"], a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100244 return result_tens
245
246 # Ensure new output type has correct qinfo
247 if error_name == ErrorIf.WrongOutputType:
248 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000249 qinfo = [
250 TosaQuantGen.getZeroPoint(self, a.dtype),
251 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
252 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100253
254 # Invalidate Input/Output list for error if checks.
255 input_list = [a.name]
256 output_list = [result_tens.name]
257 pCount, cCount = op["operands"]
258 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000259 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
260 self, error_name, input_list, output_list
261 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100262
Les Bell729b0352021-11-24 10:28:21 +0000263 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100264 self.ser,
265 validator_fcns,
266 error_name,
267 op=op,
268 input_dtype=a.dtype,
269 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000270 qinfo=qinfo,
Luke Hutton261b7b62023-01-10 14:50:31 +0000271 result_tensors=[result_tens],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100272 input_list=input_list,
273 output_list=output_list,
274 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000275 ):
276 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100277
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000278 attr = None
279 if op["op"] == Op.NEGATE:
280 attr = ts.TosaSerializerAttribute()
281 attr.NegateAttribute(qinfo[0], qinfo[1])
282
283 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700284 return result_tens
285
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100286 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000287 result_tens = OutputShaper.binaryBroadcastOp(
288 self.ser, self.rng, a, b, error_name
289 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100290
291 # Invalidate Input/Output list for error if checks.
292 input_list = [a.name, b.name]
293 output_list = [result_tens.name]
294 pCount, cCount = op["operands"]
295 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000296 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
297 self, error_name, input_list, output_list
298 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100299
Les Bell729b0352021-11-24 10:28:21 +0000300 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100301 self.ser,
302 validator_fcns,
303 error_name,
304 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000305 input1=a,
306 input2=b,
307 input_dtype=a.dtype,
308 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000309 result_tensors=[result_tens],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100310 input_list=input_list,
311 output_list=output_list,
312 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000313 ):
314 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100315
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000316 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -0700317 return result_tens
318
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100319 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700320 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000321 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700322 return result_tens
323
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000324 def build_arithmetic_right_shift(
325 self, op, a, b, round, validator_fcns=None, error_name=None
326 ):
327 result_tens = OutputShaper.binaryBroadcastOp(
328 self.ser, self.rng, a, b, error_name
329 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100330
331 # Invalidate Input/Output list for error if checks.
332 input_list = [a.name, b.name]
333 output_list = [result_tens.name]
334 pCount, cCount = op["operands"]
335 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000336 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
337 self, error_name, input_list, output_list
338 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100339
Les Bell729b0352021-11-24 10:28:21 +0000340 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100341 self.ser,
342 validator_fcns,
343 error_name,
344 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000345 input1=a,
346 input2=b,
347 input_dtype=a.dtype,
348 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000349 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100350 input_list=input_list,
351 output_list=output_list,
352 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000353 ):
354 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800355
356 attr = ts.TosaSerializerAttribute()
357 attr.ArithmeticRightShiftAttribute(round)
358
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000359 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800360 return result_tens
361
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100362 def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000363 result_tens = OutputShaper.binaryBroadcastOp(
364 self.ser, self.rng, a, b, error_name
365 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700366
367 # Special for multiply:
368 # Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100369 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Eric Kunzee5e26762020-10-13 16:11:07 -0700370 result_tens.setDtype(DType.INT32)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100371 if error_name == ErrorIf.WrongOutputType:
372 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
373 outputDType = self.rng.choice(all_dtypes)
374 result_tens.setDtype(outputDType)
375
376 # Invalidate Input/Output list for error if checks.
377 input_list = [a.name, b.name]
378 output_list = [result_tens.name]
379 pCount, cCount = op["operands"]
380 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000381 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
382 self, error_name, input_list, output_list
383 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100384
Les Bell729b0352021-11-24 10:28:21 +0000385 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100386 self.ser,
387 validator_fcns,
388 error_name,
389 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000390 input1=a,
391 input2=b,
392 input_dtype=a.dtype,
393 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000394 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100395 input_list=input_list,
396 output_list=output_list,
397 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000398 ):
399 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700400
Kevin Chengaee1fac2020-11-11 13:54:06 -0800401 attr = ts.TosaSerializerAttribute()
402 attr.MulAttribute(shift)
403
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000404 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700405 return result_tens
406
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100407 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
408 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700409
Kevin Chengfe392ce2021-10-18 21:51:55 +0000410 attr = ts.TosaSerializerAttribute()
411 attr.TableAttribute(table)
412
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100413 # Invalidate Input/Output list for error if checks.
414 input_list = [a.name]
415 output_list = [result_tens.name]
416 pCount, cCount = op["operands"]
417 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000418 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
419 self, error_name, input_list, output_list
420 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100421
Les Bell729b0352021-11-24 10:28:21 +0000422 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100423 self.ser,
424 validator_fcns,
425 error_name,
426 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000427 input_shape=a.shape,
428 input_dtype=a.dtype,
429 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000430 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100431 input_list=input_list,
432 output_list=output_list,
433 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000434 ):
435 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100436
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000437 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700438
439 return result_tens
440
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100441 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
442 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
443
444 # Invalidate Input/Output list for error if checks.
445 input_list = [cond.name, a.name, b.name]
446 output_list = [result_tens.name]
447 pCount, cCount = op["operands"]
448 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000449 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
450 self, error_name, input_list, output_list
451 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100452
Les Bell729b0352021-11-24 10:28:21 +0000453 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100454 self.ser,
455 validator_fcns,
456 error_name,
457 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000458 input1=cond,
459 input2=a,
460 input3=b,
461 input_shape=a.shape,
462 input_dtype=a.dtype,
463 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000464 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100465 input_list=input_list,
466 output_list=output_list,
467 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000468 ):
469 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100470
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000471 self.ser.addOperator(
472 op["op"],
473 input_list,
474 output_list,
475 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700476 return result_tens
477
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100478 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000479 result_tens = OutputShaper.binaryComparisonOp(
480 self.ser, self.rng, a, b, error_name
481 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100482
483 # Invalidate Input/Output list for error if checks.
484 input_list = [a.name, b.name]
485 output_list = [result_tens.name]
486 pCount, cCount = op["operands"]
487 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000488 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
489 self, error_name, input_list, output_list
490 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100491
Les Bell729b0352021-11-24 10:28:21 +0000492 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100493 self.ser,
494 validator_fcns,
495 error_name,
496 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000497 input1=a,
498 input2=b,
499 input_shape=a.shape,
500 input_dtype=a.dtype,
501 output_shape=result_tens.shape,
502 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000503 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100504 input_list=input_list,
505 output_list=output_list,
506 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000507 ):
508 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100509
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000510 self.ser.addOperator(
511 op["op"],
512 input_list,
513 output_list,
514 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700515 return result_tens
516
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100517 def build_argmax(self, op, a, axis, validator_fcns, error_name):
518 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
519
520 # Invalidate Input/Output list for error if checks.
521 input_list = [a.name]
522 output_list = [result_tens.name]
523 pCount, cCount = op["operands"]
524 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000525 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
526 self, error_name, input_list, output_list
527 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100528
Les Bell729b0352021-11-24 10:28:21 +0000529 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100530 self.ser,
531 validator_fcns,
532 error_name,
533 op=op,
534 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000535 input_shape=a.shape,
536 input_dtype=a.dtype,
537 output_shape=result_tens.shape,
538 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000539 result_tensors=[result_tens],
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100540 input_list=input_list,
541 output_list=output_list,
542 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000543 ):
544 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700545
546 attr = ts.TosaSerializerAttribute()
547 attr.AxisAttribute(axis)
548
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000549 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700550 return result_tens
551
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000552 def build_pool2d(
553 self,
554 op,
555 input,
James Ward8b390432022-08-12 20:48:56 +0100556 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000557 stride,
558 pad,
559 kernel,
560 validator_fcns=None,
561 error_name=None,
562 qinfo=None,
563 ):
564 result_tens = OutputShaper.pool2dOp(
565 self.ser, self.rng, input, kernel, stride, pad, error_name
566 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100567
568 # Ensure new output type has correct qinfo
569 if error_name == ErrorIf.WrongInputType:
570 if input.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000571 qinfo = [
572 TosaQuantGen.getZeroPoint(self, input.dtype),
573 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
574 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100575
576 # Invalidate Input/Output list for error if checks.
577 input_list = [input.name]
578 output_list = [result_tens.name]
579 pCount, cCount = op["operands"]
580 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000581 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
582 self, error_name, input_list, output_list
583 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100584
Les Bell729b0352021-11-24 10:28:21 +0000585 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100586 self.ser,
587 validator_fcns,
588 error_name,
589 op=op,
590 input_shape=input.shape,
591 input_dtype=input.dtype,
592 output_shape=result_tens.shape,
593 output_dtype=result_tens.dtype,
594 kernel=kernel,
595 stride=stride,
596 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000597 qinfo=qinfo,
Luke Hutton261b7b62023-01-10 14:50:31 +0000598 result_tensors=[result_tens],
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100599 input_list=input_list,
600 output_list=output_list,
601 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000602 ):
603 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700604
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000605 if qinfo is None:
606 qinfo = [0, 0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700607
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000608 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100609 attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000610
611 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700612 return result_tens
613
James Ward8b390432022-08-12 20:48:56 +0100614 def build_maxpool2d(
615 self,
616 op,
617 input,
618 stride,
619 pad,
620 kernel,
621 validator_fcns=None,
622 error_name=None,
623 qinfo=None,
624 ):
625 # Same as build_pool2d but manually sets accum_dtype value
626 # (maxpool has no accum_dtype)
627 return self.build_pool2d(
628 op,
629 input,
630 DType.UNKNOWN,
631 stride,
632 pad,
633 kernel,
634 validator_fcns,
635 error_name,
636 qinfo,
637 )
638
def build_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D convolution operator to the serializer.

    Args:
        op: Operator description dict. ``op["op"]`` is the opcode passed to
            ``addOperator`` and ``op["operands"]`` is summed for the operand
            count.
        ifm, filter: Input and weight tensors (``name``, ``dtype`` and
            ``shape`` are read).
        bias: Bias tensor (only ``name`` is read).
        accum_dtype: Forwarded to ``OutputShaper.conv2dOp``.
        strides, padding, dilations: Convolution geometry. ``padding`` must
            have exactly four entries.
        validator_fcns, error_name: ERROR_IF validators and the error case
            being generated, if any.
        qinfo: Zero-point pair. ``qinfo[0]`` and ``qinfo[1]`` are written
            into the ConvAttribute.

    Returns:
        The output tensor, or None (with no operator added) when
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    assert len(padding) == 4
    out_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # For a WrongInputType case whose input is not INT8/UINT8, rebuild the
    # zero points from the actual input and output dtypes.
    eight_bit_int_types = (DType.INT8, DType.UINT8)
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in eight_bit_int_types:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    operand_count = sum(op["operands"])
    # The ERROR_IF helper may invalidate the name lists for the chosen case.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_count,
        output_list=outputs,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], inputs, outputs, conv_attr)
    return out_tensor
710
def build_conv3d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 3D convolution operator to the serializer.

    ``padding`` must have exactly six entries. On the success path ``qinfo``
    must be indexable at 0 and 1, since both values go into the
    ConvAttribute. ``op["operands"]`` is summed to give the operand count.

    Returns the tensor produced by ``OutputShaper.conv3dOp``. Returns None
    without adding an operator when ``TosaErrorValidator.evValidateErrorIfs``
    returns a falsy value.
    """
    assert len(padding) == 6
    out = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    needs_new_zero_points = (
        error_name == ErrorIf.WrongInputType
        and ifm.dtype not in (DType.INT8, DType.UINT8)
    )
    if needs_new_zero_points:
        # Recompute zero points from the input and output dtypes.
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out.dtype),
        ]

    n_operands = sum(op["operands"])
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_shape=filter.shape,
        weight_dtype=filter.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        qinfo=qinfo,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_list=in_names,
        output_list=out_names,
        num_operands=n_operands,
    )
    if not valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return out
782
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D transpose-convolution operator to the serializer.

    ``out_pad`` must have exactly four entries. ``output_shape`` is passed
    to ``OutputShaper.transposeConv2DOp`` and also stored in the
    TransposeConvAttribute. No dilation is given to the validators here.

    Returns the output tensor, or None (no operator added) when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    assert len(out_pad) == 4
    result = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    is_eight_bit_int = ifm.dtype in (DType.INT8, DType.UINT8)
    if error_name == ErrorIf.WrongInputType and not is_eight_bit_int:
        # Zero points follow the actual input/output dtypes for this case.
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result.dtype),
        ]

    total_operands = sum(op["operands"])
    tensor_inputs, tensor_outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [result.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_shape=filter.shape,
        weight_dtype=filter.dtype,
        output_shape=result.shape,
        output_dtype=result.dtype,
        qinfo=qinfo,
        pad=out_pad,
        stride=stride,
        input_list=tensor_inputs,
        output_list=tensor_outputs,
        num_operands=total_operands,
    )
    if not passed:
        return None

    tconv_attr = ts.TosaSerializerAttribute()
    tconv_attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1]
    )
    self.ser.addOperator(op["op"], tensor_inputs, tensor_outputs, tconv_attr)
    return result
845
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a depthwise 2D convolution operator to the serializer.

    Unlike ``build_conv2d``, this does not assert the length of ``padding``.
    ``qinfo[0]`` and ``qinfo[1]`` are written into the ConvAttribute.

    Returns the tensor produced by ``OutputShaper.depthwiseConv2dOp``.
    Returns None (no operator added) when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    output = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    if error_name == ErrorIf.WrongInputType:
        if ifm.dtype not in (DType.INT8, DType.UINT8):
            # Rebuild zero points from the input and output dtypes.
            qinfo = [
                TosaQuantGen.getZeroPoint(self, ifm.dtype),
                TosaQuantGen.getZeroPoint(self, output.dtype),
            ]

    operands = sum(op["operands"])
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [output.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_shape=filter.shape,
        weight_dtype=filter.dtype,
        output_shape=output.shape,
        output_dtype=output.dtype,
        qinfo=qinfo,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_list=names_in,
        output_list=names_out,
        num_operands=operands,
    )
    if not ok:
        return None

    attribute = ts.TosaSerializerAttribute()
    attribute.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], names_in, names_out, attribute)
    return output
916
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a fully-connected operator to the serializer.

    ``op["operands"]`` is unpacked as a two-element pair and summed for the
    operand count. ``qinfo[0]`` and ``qinfo[1]`` go into the
    FullyConnectedAttribute, and ``accum_dtype`` is also forwarded to the
    ERROR_IF validators.

    Returns the output tensor, or None (no operator added) when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    fc_out = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    in_list, out_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [fc_out.name]
    )

    checks = dict(
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=fc_out.shape,
        output_dtype=fc_out.dtype,
        qinfo=qinfo,
        result_tensors=[fc_out],
        input_list=in_list,
        output_list=out_list,
        num_operands=operand_total,
        accum_dtype=accum_dtype,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    fc_attr = ts.TosaSerializerAttribute()
    fc_attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], in_list, out_list, fc_attr)
    return fc_out
965
def build_matmul(
    self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a matrix-multiply operator with inputs ``a`` and ``b``.

    The shapes and dtypes of both inputs and ``accum_dtype`` are passed to
    the ERROR_IF validators. ``qinfo[0]`` and ``qinfo[1]`` go into the
    MatMulAttribute.

    Returns the output tensor, or None (no operator added) when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    product = OutputShaper.matmulOp(self.ser, self.rng, a, b, accum_dtype, error_name)

    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    lhs_rhs_names, product_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [product.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=product.shape,
        output_dtype=product.dtype,
        qinfo=qinfo,
        result_tensors=[product],
        input_list=lhs_rhs_names,
        output_list=product_names,
        num_operands=operand_total,
        accum_dtype=accum_dtype,
    )
    if not is_valid:
        return None

    mm_attr = ts.TosaSerializerAttribute()
    mm_attr.MatMulAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], lhs_rhs_names, product_names, mm_attr)
    return product
1007
def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
    """Add a reduce operator over ``axis`` of tensor ``a``.

    ``validator_fcns`` is a required positional argument here, unlike most
    sibling builders. ``axis`` is stored in an AxisAttribute.

    Returns the tensor produced by ``OutputShaper.reduceOp``, or None (no
    operator added) when ``TosaErrorValidator.evValidateErrorIfs`` returns a
    falsy value.
    """
    reduced = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    p_count, c_count = op["operands"]
    src_names, dst_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [reduced.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=reduced.shape,
        output_dtype=reduced.dtype,
        result_tensors=[reduced],
        input_list=src_names,
        output_list=dst_names,
        num_operands=p_count + c_count,
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], src_names, dst_names, axis_attr)
    return reduced
1042
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Add a clamp operator on ``a`` with randomly drawn bounds.

    Two values are drawn with ``self.getRandNumberDType(a.dtype)``. Normally
    the smaller becomes ``min_val`` and the larger ``max_val``. For
    ``ErrorIf.MaxSmallerMin`` the pair is redrawn until the values differ,
    then assigned the other way round so that max < min.

    For BF16/FP16/FP32 inputs the bounds are passed as the last two value
    arguments of ClampAttribute (FP16 bounds are first cast to float32).
    Otherwise they are passed as the first two value arguments.

    Returns the output tensor, or None (no operator added) when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    clamped = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def _draw_pair():
        # Draw order matters for RNG reproducibility: first then second.
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    pair = _draw_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # Equal values cannot trigger the error, so redraw until distinct.
        while pair[0] == pair[1]:
            pair = _draw_pair()
        min_val, max_val = max(pair), min(pair)
    else:
        min_val, max_val = min(pair), max(pair)

    p_count, c_count = op["operands"]
    src_names, dst_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [clamped.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=clamped.shape,
        output_dtype=clamped.dtype,
        result_tensors=[clamped],
        input_list=src_names,
        output_list=dst_names,
        num_operands=p_count + c_count,
    ):
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    is_float_type = a.dtype in (DType.BF16, DType.FP16, DType.FP32)
    if not is_float_type:
        clamp_attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        clamp_attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], src_names, dst_names, clamp_attr)
    return clamped
1098
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Add a leaky-ReLU operator on ``a``.

    The attribute value is a random number drawn for ``DType.FP32``.
    ``validator_fcns`` is accepted but no ERROR_IF validation is done here.
    Always returns the output tensor.
    """
    relu_out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    relu_attr = ts.TosaSerializerAttribute()
    alpha = self.getRandNumberDType(DType.FP32)
    relu_attr.LeakyReluAttribute(alpha)
    self.ser.addOperator(op["op"], [a.name], [relu_out.name], relu_attr)
    return relu_out
1107
# NOTE: Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Add a PReLU operator on ``a`` with no attribute.

    ``validator_fcns`` is accepted but no ERROR_IF validation is done here.
    Always returns the output tensor.
    """
    prelu_out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [prelu_out.name])
    return prelu_out
1114
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Add a sigmoid operator on ``a`` (no attribute).

    Returns the output tensor, or None (no operator added) when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    sig_out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    # The ERROR_IF helper may invalidate the name lists for the chosen case.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [sig_out.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=sig_out.shape,
        input_dtype=a.dtype,
        output_dtype=sig_out.dtype,
        result_tensors=[sig_out],
        input_list=ins,
        output_list=outs,
        num_operands=operand_total,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], ins, outs)
    return sig_out
1145
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Add a tanh operator on ``a`` (no attribute).

    Returns the output tensor, or None (no operator added) when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    tanh_out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [tanh_out.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=tanh_out.shape,
        output_dtype=tanh_out.dtype,
        result_tensors=[tanh_out],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return tanh_out
1176
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Add a concat operator over a variable number of tensors.

    The last positional argument is the concat axis and the preceding ones
    are the input tensors. The axis is asserted to be exactly ``int`` unless
    generating the ``ErrorIf.WrongInputType`` case. The axis is stored in an
    AxisAttribute.

    Returns the output tensor, or None (no operator added) when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    if error_name != ErrorIf.WrongInputType:
        assert type(a[-1]) is int

    # Variable-length input lists carry the axis as their final element.
    axis = a[-1]
    tensors = a[:-1]

    joined = OutputShaper.concatOp(
        self.ser, self.rng, axis, *tensors, error_name=error_name
    )

    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [t.name for t in tensors], [joined.name]
    )

    first = tensors[0]
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=first.shape,
        input_dtype=first.dtype,
        output_shape=joined.shape,
        output_dtype=joined.dtype,
        inputs=tensors,
        result_tensors=[joined],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)
    return joined
Eric Kunzee5e26762020-10-13 16:11:07 -07001225
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a pad operator on ``a``.

    The PadAttribute receives ``padding.flatten()`` (so ``padding`` must
    support ``flatten``) together with both pad constants. ``qinfo`` is only
    forwarded to the ERROR_IF validators.

    Returns the output tensor, or None (no operator added) when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    padded = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    pad_attr = ts.TosaSerializerAttribute()
    flat_padding = padding.flatten()
    pad_attr.PadAttribute(
        self.ser.builder, flat_padding, pad_const_int, pad_const_float
    )

    p_count, c_count = op["operands"]
    src, dst = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [padded.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=padded.shape,
        input_dtype=a.dtype,
        output_dtype=padded.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[padded],
        input_list=src,
        output_list=dst,
        num_operands=p_count + c_count,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], src, dst, pad_attr)
    return padded
Eric Kunzee5e26762020-10-13 16:11:07 -07001274
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Add a reshape operator turning ``a`` into ``newShape``.

    ``newShape`` is stored in a ReshapeAttribute. Returns the output tensor,
    or None (no operator added) when ``TosaErrorValidator.evValidateErrorIfs``
    returns a falsy value.
    """
    reshaped = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    p_count, c_count = op["operands"]
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [reshaped.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=reshaped.shape,
        output_dtype=reshaped.dtype,
        result_tensors=[reshaped],
        input_list=ins,
        output_list=outs,
        num_operands=p_count + c_count,
    )
    if not passed:
        return None

    shape_attr = ts.TosaSerializerAttribute()
    shape_attr.ReshapeAttribute(newShape)
    self.ser.addOperator(op["op"], ins, outs, shape_attr)
    return reshaped
1310
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Add a reverse operator on ``a`` along ``axis``.

    The output shape comes from ``OutputShaper.unaryOp`` and ``axis`` is
    stored in an AxisAttribute. Returns the output tensor, or None (no
    operator added) when ``TosaErrorValidator.evValidateErrorIfs`` returns a
    falsy value.
    """
    reversed_out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    src_names, dst_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [reversed_out.name]
    )

    checks = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=reversed_out.shape,
        input_dtype=a.dtype,
        output_dtype=reversed_out.dtype,
        result_tensors=[reversed_out],
        input_list=src_names,
        output_list=dst_names,
        num_operands=operand_total,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], src_names, dst_names, axis_attr)
    return reversed_out
1345
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Add a transpose operator on ``a`` using permutation ``perms``.

    ``perms`` is stored in a TransposeAttribute and also passed to the
    ERROR_IF validators. Returns the output tensor, or None (no operator
    added) when ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy
    value.
    """
    permuted = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    p_count, c_count = op["operands"]
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [permuted.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=permuted.shape,
        output_dtype=permuted.dtype,
        perms=perms,
        result_tensors=[permuted],
        input_list=ins,
        output_list=outs,
        num_operands=p_count + c_count,
        input1=a,
    )
    if not ok:
        return None

    self.ser.addOperator(op["op"], ins, outs, perm_attr)
    return permuted
1381
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Add a slice operator taking ``size`` elements of ``a`` from ``start``.

    ``start`` and ``size`` are stored in a SliceAttribute and also passed to
    the ERROR_IF validators. Returns the output tensor, or None (no operator
    added) when ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy
    value.
    """
    sliced = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    p_count, c_count = op["operands"]
    src_names, dst_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [sliced.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=sliced.shape,
        input_dtype=a.dtype,
        output_dtype=sliced.dtype,
        start=start,
        size=size,
        result_tensors=[sliced],
        input_list=src_names,
        output_list=dst_names,
        num_operands=p_count + c_count,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], src_names, dst_names, slice_attr)
    return sliced
1420
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Build a tile operator that repeats ``a`` according to ``multiples``.

    Args:
        op: Operator description dict. ``op["op"]`` is the opcode and the two
            counts in ``op["operands"]`` are summed into ``num_operands``.
        a: Input tensor.
        multiples: Per-dimension repeat counts, stored in the TileAttribute.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The result tensor, or None (with no operator added) when ERROR_IF
        validation returns a falsy result.
    """
    tiled = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [tiled.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=tiled.shape,
        input_dtype=a.dtype,
        output_dtype=tiled.dtype,
        result_tensors=[tiled],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
        input1=a,
    )
    if not passed:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)
    self.ser.addOperator(op["op"], in_names, out_names, tile_attr)
    return tiled
1455
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Build a gather operator over ``values`` using freshly generated indices.

    A new INT32 constant index tensor of shape ``[values.shape[0], W]`` is
    added, where ``W`` comes from ``self.randInt`` over
    ``self.args.tensor_shape_range``. Every index is drawn from
    ``[0, values.shape[1])``, so it never exceeds that dimension.

    Args:
        op: Operator description dict. ``op["op"]`` is the opcode and the two
            counts in ``op["operands"]`` are summed into ``num_operands``.
        values: Tensor to gather from.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The result tensor, or None (with no operator added) when ERROR_IF
        validation returns a falsy result.
    """
    k_extent = values.shape[1]
    w_extent = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )
    index_data = np.int32(
        self.rng.integers(low=0, high=k_extent, size=[values.shape[0], w_extent])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    gathered = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [gathered.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=gathered.shape,
        input_dtype=values.dtype,
        output_dtype=gathered.dtype,
        result_tensors=[gathered],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return gathered
1502
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Build a scatter operator writing ``input`` into ``values_in``.

    A new INT32 constant index tensor of shape
    ``[values_in.shape[0], input.shape[1]]`` is added. Every index is drawn
    from ``[0, values_in.shape[1])``, so it never exceeds that dimension.

    Args:
        op: Operator description dict. ``op["op"]`` is the opcode and the two
            counts in ``op["operands"]`` are summed into ``num_operands``.
        values_in: Tensor being scattered into.
        input: Tensor supplying the scattered values.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The result tensor, or None (with no operator added) when ERROR_IF
        validation returns a falsy result.
    """
    k_extent = values_in.shape[1]
    w_extent = input.shape[1]
    index_data = np.int32(
        self.rng.integers(low=0, high=k_extent, size=[values_in.shape[0], w_extent])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    scattered = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input.name],
        [scattered.name],
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=scattered.shape,
        input_dtype=values_in.dtype,
        output_dtype=scattered.dtype,
        result_tensors=[scattered],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return scattered
1547
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a resize operator for ``input``.

    ``mode``, ``scale``, ``offset`` and ``border`` are stored in the
    ResizeAttribute. ``input_dtype`` and ``output_dtype`` are the requested
    types, and are also what ERROR_IF validation is given.

    Args:
        op: Operator description dict. ``op["op"]`` is the opcode and the two
            counts in ``op["operands"]`` are summed into ``num_operands``.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The result tensor, or None (with no operator added) when ERROR_IF
        validation returns a falsy result.
    """
    resized = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [resized.name]
    )

    error_if_args = dict(
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=resized.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[resized],
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return resized
1609
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build a two-input/two-output identity operator (presumably IDENTITYN).

    Args:
        op: Operator description dict whose ``op["op"]`` entry is the opcode,
            as for the other ``build_*`` helpers. A bare opcode is also
            accepted for backward compatibility.
        val: First input tensor.
        val2: Second input tensor.
        validator_fcns: Accepted for signature compatibility; this builder
            performs no ERROR_IF validation.
        error_name: Forwarded to ``OutputShaper.unaryOp`` for both outputs.

    Returns:
        The result tensor for ``val`` only. The result tensor for ``val2`` is
        attached to the operator but not returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)

    # Every other builder passes op["op"] to addOperator. Passing the whole
    # description dict would hand the serializer something that is not an opcode.
    op_code = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        op_code, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1617
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Mark the already-created tensor ``val`` as a graph output and return it.

    ``op``, ``validator_fcns`` and ``error_name`` are unused. They exist so the
    signature matches the other ``build_*`` helpers.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001621
1622 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Build a cast operator converting ``val`` to ``out_dtype``.

    Args:
        op: Operator description dict. ``op["op"]`` is the opcode and the two
            counts in ``op["operands"]`` are summed into ``num_operands``.
        val: Input tensor.
        out_dtype: Requested output type, passed to
            ``OutputShaper.typeConversionOp``.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The result tensor, or None (with no operator added) when ERROR_IF
        validation returns a falsy result.
    """
    converted = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [converted.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=converted.shape,
        input_dtype=val.dtype,
        output_dtype=converted.dtype,
        result_tensors=[converted],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return converted
1655
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a rescale operator converting ``val`` to ``out_dtype``.

    Zero points are picked from the input/output dtype:
    - INT8 and UINT8 get a random zero point.
    - The zero-point ERROR_IF cases get a random non-zero zero point.
    - UINT16 gets 0 or 32768.
    - Everything else gets 0.

    A random scale (one per channel when ``per_channel``, otherwise one) is
    converted to multiplier/shift pairs by
    ``TosaQuantGen.computeMultiplierAndShift``.

    For valid (``error_name is None``) scale32 tests, the input's saved
    placeholder data is clamped so that ``value - input_zp`` lies in
    ``[-(1 << (shift - 1)), (1 << (shift - 1)) - 1]``. If clamping changes
    anything, the ``.npy`` file under ``self.basePath``/``self.testPath`` is
    overwritten.

    Args:
        op: Operator description dict. ``op["op"]`` is the opcode and the two
            counts in ``op["operands"]`` are summed into ``num_operands``.
        val: Input tensor.
        out_dtype: Requested output type.
        scale32: Selects 32-bit multipliers (else 16-bit scaling caps apply).
        double_round, per_channel: Stored in the RescaleAttribute.
            ``per_channel`` also selects one scale per ``val.shape[-1]`` entry.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The result tensor, or None (with no operator added) when ERROR_IF
        validation returns a falsy result.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # Number of multiplier/shift pairs: one per channel (last dim) or one total.
    if per_channel:
        nc = val.shape[-1]
    else:
        nc = 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        # Force a non-zero zero point so the ERROR_IF case is exercised.
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        # Force a non-zero zero point so the ERROR_IF case is exercised.
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Compute the bounds with Python ints. shift_arr[i] is a NumPy int32
        # scalar, and under NEP 50 promotion (NumPy >= 2) shifting it by 32 or
        # more does not produce the true bound. With scales clipped down to
        # 2**-31 above, large shifts are possible.
        # Assumes computeMultiplierAndShift returns shift >= 1 — TODO confirm.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -(1 << (shift - 1))
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values = np.load(
            os.path.join(self.basePath, self.testPath, val.placeholderFilename)
        )
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.all(np.array_equal(values, val_adj)):
            # Values changed so overwrite file with new values
            np.save(
                os.path.join(self.basePath, self.testPath, val.placeholderFilename),
                val_adj,
                False,
            )

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1810
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor for a COND_IF test.

    Normally this is a rank-0 ``DType.BOOL`` constant holding ``cond``. Two
    ERROR_IF cases change it:
    - ``ErrorIf.CondIfCondNotMatchingBool`` uses the type returned by
      ``get_wrong_output_type`` (presumably non-BOOL).
    - ``ErrorIf.CondIfCondShapeNotSizeOne`` uses a randomly chosen shape of
      ``[2]`` or ``[1, 2]``.
    """
    cond_type = (
        get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
1827
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose then/else blocks each output an INT32 constant.

    Only ``then_tens.shape`` is used; ``else_tens`` is ignored. Each block
    holds a random INT32 constant of that shape. For the
    ``CondIfOutputListThenGraphMismatch`` / ``CondIfOutputListElseGraphMismatch``
    ERROR_IF cases, the corresponding block instead outputs a constant whose
    shape has been deliberately altered.

    Returns:
        The COND_IF result tensor, or None when
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy result.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)
    out_shape = then_tens.shape

    then_mismatch = ErrorIf.CondIfOutputListThenGraphMismatch
    else_mismatch = ErrorIf.CondIfOutputListElseGraphMismatch
    if error_name in [then_mismatch, else_mismatch]:
        # Perturb every dimension. Sizes above 3 may shrink by up to 3;
        # smaller sizes only grow.
        bad_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(bad_shape):
            bad_shape[idx] = dim + (
                self.rng.choice([-3, -2, 2, 3])
                if dim > 3
                else self.rng.choice([1, 2, 4])
            )
        bad_arr = np.int32(self.rng.integers(0, 256, size=bad_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The COND_IF output mirrors the (expected) block output shape.
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block consists solely of a constant that is also its output.
    for block_name, mismatch, block_arr in (
        (then_block, then_mismatch, then_arr),
        (else_block, else_mismatch, else_arr),
    ):
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch:
            block_out = self.ser.addConst(bad_shape, DType.INT32, bad_arr)
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_out)

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if validated else None
1896
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose then/else blocks apply a binary op to ``a`` and ``b``.

    For FP32/BF16/FP16/INT32 inputs the then block uses ADD and the else
    block uses SUB. For INT8/INT16 inputs they use LOGICAL_RIGHT_SHIFT and
    LOGICAL_LEFT_SHIFT respectively.

    For the CondIf input/output list mismatch ERROR_IF cases, the selected
    block is given an input (or output) whose shape differs from ``a``.

    Args:
        op: Operator description dict (``op["op"]`` is the COND_IF opcode). It
            is also what the ERROR_IF validators receive.
        a, b: Operand tensors shared by the COND_IF and both blocks.
        cond: Condition value; see ``_get_condition_tensor``.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The COND_IF result tensor, or None when ERROR_IF validation returns
        a falsy result.

    Raises:
        AssertionError: if ``a.dtype`` has no then/else op mapping.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # NOTE(review): unlike build_cond_if_const, small dimensions are not
        # protected here, so the altered shape can contain zero or negative
        # sizes. Confirm that is intended for these ERROR_IF tests.
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Explicit raise so the check survives `python -O`.
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # Use a distinct loop variable: reusing ``op`` here would overwrite the
    # operator dict that is passed to the ERROR_IF validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
1979
Matthew Haddon630c17c2021-10-14 15:05:41 +01001980 def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001981 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07001982
Kevin Cheng550ccc52021-03-03 11:21:43 -08001983 cond_block = "COND_BLOCK"
1984 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001985
1986 attr = ts.TosaSerializerAttribute()
1987 attr.WhileLoopAttribute(cond_block, body_block)
1988
1989 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001990 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001991 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001992 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07001993
1994 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08001995 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
1996 a_out = self.ser.addIntermediate(a.shape, a.dtype)
Matthew Haddon630c17c2021-10-14 15:05:41 +01001997 if error_name == ErrorIf.InputListOutputListMismatch:
1998 incorrect_acc = deepcopy(acc)
1999 for i in range(len(incorrect_acc.shape)):
2000 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2001 acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
2002 else:
2003 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002004
2005 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08002006 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002007 op["op"],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002008 [iter.name, a.name, acc.name],
2009 [iter_out.name, a_out.name, acc_out.name],
2010 attr,
2011 )
Kevin Chengb227ae52021-09-02 13:43:17 -07002012 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07002013
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002014 if error_name in [
2015 ErrorIf.InputListCondGraphMismatch,
2016 ErrorIf.InputListBodyGraphInputMismatch,
2017 ErrorIf.InputListBodyGraphOutputMismatch,
2018 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01002019 incorrect_iter = deepcopy(iter)
2020 for i in range(len(incorrect_iter.shape)):
2021 incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
2022 if len(incorrect_iter.shape) == 0:
2023 incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
2024
2025 incorrect_acc = deepcopy(acc)
2026 for i in range(len(incorrect_acc.shape)):
2027 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2028
Eric Kunzee5e26762020-10-13 16:11:07 -07002029 # COND block (input: iter, output: cond_tens )
Jerry Ge9e94af82022-10-27 09:57:00 -07002030 self.ser.addBasicBlock(cond_block)
2031
Matthew Haddon630c17c2021-10-14 15:05:41 +01002032 if error_name == ErrorIf.InputListCondGraphMismatch:
2033 self.ser.addInputTensor(incorrect_iter)
2034 self.ser.addInputTensor(a)
2035 self.ser.addInputTensor(incorrect_acc)
2036 else:
2037 self.ser.addInputTensor(iter)
2038 self.ser.addInputTensor(a)
2039 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002040 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002041
2042 if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002043 cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002044 else:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002045 cond_type = DType.BOOL
2046 if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
2047 choice = self.rng.choice([1, 2])
2048 if choice == 1:
2049 cond_shape = [3]
2050 else:
2051 cond_shape = [1, 2]
2052 else:
2053 cond_shape = []
2054 cond_tens = self.ser.addOutput(cond_shape, cond_type)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002055
Kevin Cheng550ccc52021-03-03 11:21:43 -08002056 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002057
2058 # BODY block (input: a, acc, iter, output: a, acc, iter)
2059 # Note that local intermediate tensors need to be declared here for the outputs
Jerry Ge9e94af82022-10-27 09:57:00 -07002060 self.ser.addBasicBlock(body_block)
2061
Matthew Haddon630c17c2021-10-14 15:05:41 +01002062 if error_name == ErrorIf.InputListBodyGraphInputMismatch:
2063 self.ser.addInputTensor(incorrect_iter)
2064 self.ser.addInputTensor(a)
2065 self.ser.addInputTensor(incorrect_acc)
2066 else:
2067 self.ser.addInputTensor(iter)
2068 self.ser.addInputTensor(a)
2069 self.ser.addInputTensor(acc)
2070
Kevin Cheng550ccc52021-03-03 11:21:43 -08002071 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002072
2073 if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002074 iter_body_out = self.ser.addIntermediate(
2075 incorrect_iter.shape, incorrect_iter.dtype
2076 )
2077 acc_body_out = self.ser.addIntermediate(
2078 incorrect_acc.shape, incorrect_acc.dtype
2079 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01002080 else:
2081 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2082 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
2083
Eric Kunzee5e26762020-10-13 16:11:07 -07002084 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2085 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2086 self.ser.addOutputTensor(iter_body_out)
2087 self.ser.addOutputTensor(a)
2088 self.ser.addOutputTensor(acc_body_out)
2089
Les Bell729b0352021-11-24 10:28:21 +00002090 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01002091 self.ser,
2092 validator_fcns,
2093 error_name,
2094 op=op,
Jerry Ge9e94af82022-10-27 09:57:00 -07002095 basicBlocks=self.ser.currRegion.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00002096 ):
2097 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01002098
Eric Kunzee5e26762020-10-13 16:11:07 -07002099 return acc_out
2100
def build_fft2d(
    self, op, val1, val2, inverse, validator_fcns=None, error_name=None
):
    """Add an FFT2D operator whose two inputs are val1 and val2.

    The result tensors are created by OutputShaper.fft2dOp. For ERROR_IF
    tests the input/output name lists may be deliberately invalidated before
    the ERROR_IF validators are consulted.

    Returns:
        The list of result tensors, or None when the ERROR_IF validation
        reports that the requested test cannot be produced.
    """
    result_tensors = OutputShaper.fft2dOp(
        self.ser, self.rng, val1, val2, error_name
    )

    in_names = [val1.name, val2.name]
    placeholders, consts = op["operands"]
    operand_count = placeholders + consts

    # Gather name/shape/dtype of every result tensor in one pass
    out_names, out_shapes, out_dtypes = [], [], []
    for tensor in result_tensors:
        out_names.append(tensor.name)
        out_shapes.append(tensor.shape)
        out_dtypes.append(tensor.dtype)

    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=result_tensors,
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_count,
    )
    if not is_valid:
        return None

    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse)

    self.ser.addOperator(op["op"], in_names, out_names, fft_attr)
    return result_tensors
2142
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
    """Add an RFFT2D operator whose single input is val.

    The result tensors are created by OutputShaper.rfft2dOp. For ERROR_IF
    tests the input/output name lists may be deliberately invalidated before
    the ERROR_IF validators are consulted.

    Returns:
        The list of result tensors, or None when the ERROR_IF validation
        reports that the requested test cannot be produced.
    """
    result_tensors = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    in_names = [val.name]
    placeholders, consts = op["operands"]
    operand_count = placeholders + consts

    # Gather name/shape/dtype of every result tensor in one pass
    out_names, out_shapes, out_dtypes = [], [], []
    for tensor in result_tensors:
        out_names.append(tensor.name)
        out_shapes.append(tensor.shape)
        out_dtypes.append(tensor.dtype)

    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=result_tensors,
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return result_tensors
2176
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape, rank and dtype filters used to generate tests for an op.

    Args:
        op: TOSA_OP_LIST entry; its "rank" (min, max) tuple and "types" list
            are read.
        shapeFilter: requested shapes, or None/empty for no shape filter.
        rankFilter: requested ranks, or None to use the default rank range.
        dtypeFilter: requested dtypes, or None to use every dtype of the op.
        testType: "positive" or "negative".
        validator: ERROR_IF validator (negative tests only); it is called with
            check=False to obtain its "param_reqs".

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" entries, or
        None for a negative test without a validator or an unknown testType.
    """
    # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    if rankFilter is not None:
        # Ensure rankFilter values are allowed by operator
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Ensure default behaviour is bounded by default range or by operator,
        # whichever is the smaller range of ranks.
        opRankRange = range(rmin, rmax + 1)
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            # The op allows more ranks than the default test range: use the
            # default range, but clipped to the op's limits so that no rank
            # outside (rmin, rmax) is ever generated.
            clipped = range(
                max(rmin, default_test_rank_range.start),
                min(rmax + 1, default_test_rank_range.stop),
            )
            if len(clipped) > 0:
                cleanRankFilter = clipped
            else:
                # Op range lies entirely outside the default range; take the
                # lowest ranks of the op instead (same number as the default).
                cleanRankFilter = range(rmin, rmin + len(default_test_rank_range))
    else:
        cleanRankFilter = range(rmin, rmax + 1)

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Create list of operator dtypes filtered by requested dtypes; an op
        # dtype may itself be a list (e.g. convolution type triples), which is
        # matched on its first element.
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType == "negative":
        if validator is None:
            return None

        error_arguments = validator(check=False, op=op)["param_reqs"]

        # Parameters required by the ERROR_IF condition override the clean filters
        if error_arguments["rank"] is not None:
            rankFilter = error_arguments["rank"]
        else:
            rankFilter = cleanRankFilter

        if error_arguments["dtype"] is not None:
            dtypeFilter = error_arguments["dtype"]
        else:
            dtypeFilter = cleanDtypeFilter

        if error_arguments["shape"] is not None:
            shapeFilter = error_arguments["shape"]
        else:
            # Reduce number of shapes to keep test numbers small
            shapeFilter = shapeFilter[:2]

        return {
            "shapeFilter": shapeFilter,
            "rankFilter": rankFilter,
            "dtypeFilter": dtypeFilter,
        }

    return None
2257
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests for one operator.

    Args:
        opName: key of the operator in TOSA_OP_LIST.
        shapeFilter: optional list of shapes to restrict generation to;
            None (or an empty list) means no shape restriction.
        rankFilter: optional iterable of ranks to restrict generation to.
        dtypeFilter: optional iterable of dtypes to restrict generation to.
        testType: "positive" or "negative" (ERROR_IF) tests.

    Returns:
        A list of (opName, testStr, dtype, error_name, shapeList, args)
        tuples. Positive tests rejected by any of the op's
        "invalid_test_validators" are left out.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as err:
        raise Exception("Cannot find op with name {}".format(opName)) from err

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument list of (str, []) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
                        else:
                            testStr = "{}_ERRORIF_{}_{}_{}".format(
                                opName, error_name, shapeStr, typeStr
                            )
                        if argStr:
                            testStr = "{}_{}".format(testStr, argStr)

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a
        # ERROR_IF statement. A test is dropped if ANY validator rejects it.
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2364
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build one generated test and serialize it when it is valid.

    Args:
        opName: key of the operator in TOSA_OP_LIST.
        testStr: test name, passed on to createSerializer.
        dtype_or_dtypeList: either a list with one dtype per operand, or a
            single dtype that is replicated for every operand (for CONCAT,
            once per entry of shapeList).
        error_name: ERROR_IF name for negative tests, otherwise None.
        shapeList: one shape per operand.
        testArgs: extra positional arguments for the op's build function.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as err:
        raise Exception("Cannot find op with name {}".format(opName)) from err

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif op["op"] == Op.CONCAT:
        # CONCAT takes a variable number of inputs: one dtype per shape
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * num_operands

    if op["op"] != Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    qinfo = qgen(self, op, dtype_or_dtypeList, error_name) if qgen is not None else None

    # Build the random tensor operands and the test
    tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

    try:
        if error_if_validators is None:
            # Ops without ERROR_IF support take qinfo (if any) positionally
            extra_args = [] if qinfo is None else [qinfo]
            resultName = build_fcn(self, op, *tens, *testArgs, *extra_args)
        else:
            build_kwargs = {
                "validator_fcns": error_if_validators,
                "error_name": error_name,
            }
            if qinfo is not None:
                build_kwargs["qinfo"] = qinfo
            resultName = build_fcn(self, op, *tens, *testArgs, **build_kwargs)
    except TypeError:
        # Usually a mismatch between the build function signature and the
        # generated tensors/arguments; show what was passed before re-raising.
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise

    if resultName:
        # The test is valid, serialize it
        self.serialize("test")
    else:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002455
def createDynamicOpLists(self):
    """Expand the *_TEMPLATE convolution entries of TOSA_OP_LIST.

    Each template is copied once per kernel size (e.g. "conv2d_3x1"), with
    "filter" set to the kernel and "template" set to False. Afterwards every
    entry whose "template" flag is truthy is removed. Calling this again after
    the expansion (when "conv2d_TEMPLATE" is gone) does nothing.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    def add_kernel_variant(prefix, kernel):
        # Shallow-copy "<prefix>_TEMPLATE" and specialise it for one kernel size
        variant = op_list[prefix + "_TEMPLATE"].copy()
        variant["filter"] = kernel
        variant["template"] = False
        op_list["{}_{}".format(prefix, "x".join(str(dim) for dim in kernel))] = variant

    for kernel in ([1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]):
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            add_kernel_variant(prefix, kernel)

    for kernel in ([1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]):
        add_kernel_variant("conv3d", kernel)

    # Collect first, then delete: a dict must not change size while iterated
    template_names = [
        name for name, entry in op_list.items() if entry.get("template")
    ]
    for name in template_names:
        del op_list[name]
2506
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.
    Look for missing required fields (datastructure linting).

    Every TOSA_OP_LIST entry must have a 2-element "operands" tuple, a
    4-element "build_fcn" tuple, a "types" list and an "op" field. A missing
    "rank" is set to DEFAULT_RANK_RANGE.

    Raises:
        Exception: naming the first op with a missing/invalid required field.
    """
    for op, entry in self.TOSA_OP_LIST.items():

        # Required fields
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
            ) from err

        try:
            _build, _tgen, _tvgen, _arggen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    op
                )
            ) from err

        if "types" not in entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
            )

        if "op" not in entry:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
            )

        # Put in default rank range, if missing
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07002548
# Tensor operator list (TOSA_OP_LIST), keyed by test op name. Entry fields:
# 'op': Op value of the operator under test
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#    if not specified, initOpListDefaults sets it to DEFAULT_RANK_RANGE
# 'build_fcn': tuple of (build_operator(), TensorGen function,
#    TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested
# 'qgen': optional, quantization info generator function
# 'invalid_test_validators': optional, functions that discard generated
#    positive tests which are expected to fail without an ERROR_IF reason
# 'error_if_validators': optional, ERROR_IF validators used to generate
#    negative tests
# 'template': optional, True for *_TEMPLATE entries that createDynamicOpLists
#    expands into per-kernel ops and then removes
# 'filter': kernel size of a convolution op (set by createDynamicOpLists)
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
# Floating-point, integer (excluding INT4) and boolean types
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
# FP32 and INT16 only
TYPE_FI16 = [DType.FP32, DType.INT16]

# 8/16-bit integer types plus the floating-point types
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range given to ops whose TOSA_OP_LIST entry has no 'rank' field
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002600
2601 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002602 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002603 "argmax": {
2604 "op": Op.ARGMAX,
2605 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002606 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002607 "build_fcn": (
2608 build_argmax,
2609 TosaTensorGen.tgBasic,
2610 TosaTensorValuesGen.tvgDefault,
2611 TosaArgGen.agAxis,
2612 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002613 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002614 "error_if_validators": (
2615 TosaErrorValidator.evAxisSmallerZero,
2616 TosaErrorValidator.evAxisLargerRank,
2617 TosaErrorValidator.evArgmaxOutputRankMismatch,
2618 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2619 TosaErrorValidator.evWrongRank,
2620 TosaErrorValidator.evWrongInputType,
2621 TosaErrorValidator.evWrongOutputType,
2622 TosaErrorValidator.evWrongInputList,
2623 TosaErrorValidator.evWrongOutputList,
2624 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002625 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002626 "avg_pool2d": {
2627 "op": Op.AVG_POOL2D,
2628 "operands": (1, 0),
2629 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002630 "build_fcn": (
2631 build_pool2d,
2632 TosaTensorGen.tgNHWC,
2633 TosaTensorValuesGen.tvgDefault,
2634 TosaArgGen.agPooling,
2635 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002636 "qgen": TosaQuantGen.qgUnary,
2637 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002638 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002639 "error_if_validators": (
2640 TosaErrorValidator.evKernelSmallerOne,
2641 TosaErrorValidator.evStrideSmallerOne,
2642 TosaErrorValidator.evPadSmallerZero,
2643 TosaErrorValidator.evWrongRank,
2644 TosaErrorValidator.evWrongInputType,
2645 TosaErrorValidator.evWrongOutputType,
2646 TosaErrorValidator.evWrongInputList,
2647 TosaErrorValidator.evWrongOutputList,
2648 TosaErrorValidator.evInputZeroPointNotZero,
2649 TosaErrorValidator.evOutputZeroPointNotZero,
2650 TosaErrorValidator.evPadLargerEqualKernel,
2651 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002652 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002653 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002654 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002655 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002656 "conv2d_TEMPLATE": {
2657 "op": Op.CONV2D,
2658 "operands": (1, 2),
2659 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002660 "build_fcn": (
2661 build_conv2d,
2662 TosaTensorGen.tgConv2D,
2663 TosaTensorValuesGen.tvgDefault,
2664 TosaArgGen.agConv,
2665 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002666 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002667 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002668 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2669 "error_if_validators": (
2670 TosaErrorValidator.evWrongInputType,
2671 TosaErrorValidator.evWrongOutputType,
2672 TosaErrorValidator.evWrongInputList,
2673 TosaErrorValidator.evWrongOutputList,
2674 TosaErrorValidator.evInputZeroPointNotZero,
2675 TosaErrorValidator.evWeightZeroPointNotZero,
2676 TosaErrorValidator.evPadSmallerZero,
2677 TosaErrorValidator.evStrideSmallerOne,
2678 TosaErrorValidator.evDilationSmallerOne,
2679 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002680 TosaErrorValidator.evConvOutputShapeMismatch,
2681 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002682 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002683 "template": True,
2684 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002685 # Templated operator. Filled in by createDynamicOpLists
2686 "conv3d_TEMPLATE": {
2687 "op": Op.CONV3D,
2688 "operands": (1, 2),
2689 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002690 "build_fcn": (
2691 build_conv3d,
2692 TosaTensorGen.tgConv3D,
2693 TosaTensorValuesGen.tvgDefault,
2694 TosaArgGen.agConv,
2695 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002696 "qgen": TosaQuantGen.qgConv,
2697 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002698 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2699 "error_if_validators": (
2700 TosaErrorValidator.evWrongInputType,
2701 TosaErrorValidator.evWrongOutputType,
2702 TosaErrorValidator.evWrongInputList,
2703 TosaErrorValidator.evWrongOutputList,
2704 TosaErrorValidator.evInputZeroPointNotZero,
2705 TosaErrorValidator.evWeightZeroPointNotZero,
2706 TosaErrorValidator.evPadSmallerZero,
2707 TosaErrorValidator.evStrideSmallerOne,
2708 TosaErrorValidator.evDilationSmallerOne,
2709 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002710 TosaErrorValidator.evConvOutputShapeMismatch,
2711 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002712 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002713 "template": True,
2714 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002715 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002716 "depthwise_conv2d_TEMPLATE": {
2717 "op": Op.DEPTHWISE_CONV2D,
2718 "operands": (1, 2),
2719 "filter": [1, 1],
2720 "rank": (4, 4),
2721 "build_fcn": (
2722 build_depthwise_conv2d,
2723 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002724 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002725 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002726 ),
2727 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002728 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002729 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2730 "error_if_validators": (
2731 TosaErrorValidator.evWrongInputType,
2732 TosaErrorValidator.evWrongOutputType,
2733 TosaErrorValidator.evWrongInputList,
2734 TosaErrorValidator.evWrongOutputList,
2735 TosaErrorValidator.evInputZeroPointNotZero,
2736 TosaErrorValidator.evWeightZeroPointNotZero,
2737 TosaErrorValidator.evPadSmallerZero,
2738 TosaErrorValidator.evStrideSmallerOne,
2739 TosaErrorValidator.evDilationSmallerOne,
2740 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002741 TosaErrorValidator.evConvOutputShapeMismatch,
2742 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002743 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002744 "template": True,
2745 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002746 "fully_connected": {
2747 "op": Op.FULLY_CONNECTED,
2748 "operands": (1, 2),
2749 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002750 "build_fcn": (
2751 build_fully_connected,
2752 TosaTensorGen.tgFullyConnected,
2753 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002754 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002755 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002756 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002757 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002758 "error_if_validators": (
2759 TosaErrorValidator.evInputZeroPointNotZero,
2760 TosaErrorValidator.evWeightZeroPointNotZero,
2761 TosaErrorValidator.evWrongRank,
2762 TosaErrorValidator.evWrongInputType,
2763 TosaErrorValidator.evWrongOutputType,
2764 TosaErrorValidator.evWrongInputList,
2765 TosaErrorValidator.evWrongOutputList,
2766 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002767 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002768 "matmul": {
2769 "op": Op.MATMUL,
2770 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002771 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002772 "build_fcn": (
2773 build_matmul,
2774 TosaTensorGen.tgMatmul,
2775 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002776 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002777 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002778 "qgen": TosaQuantGen.qgMatmul,
2779 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002780 "error_if_validators": (
2781 TosaErrorValidator.evInputZeroPointNotZero,
2782 TosaErrorValidator.evWrongRank,
2783 TosaErrorValidator.evWrongInputType,
2784 TosaErrorValidator.evWrongOutputType,
2785 TosaErrorValidator.evWrongInputList,
2786 TosaErrorValidator.evWrongOutputList,
2787 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002788 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002789 "max_pool2d": {
2790 "op": Op.MAX_POOL2D,
2791 "operands": (1, 0),
2792 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002793 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01002794 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002795 TosaTensorGen.tgNHWC,
2796 TosaTensorValuesGen.tvgDefault,
2797 TosaArgGen.agPooling,
2798 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002799 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002800 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002801 "error_if_validators": (
2802 TosaErrorValidator.evKernelSmallerOne,
2803 TosaErrorValidator.evStrideSmallerOne,
2804 TosaErrorValidator.evPadSmallerZero,
2805 TosaErrorValidator.evWrongRank,
2806 TosaErrorValidator.evWrongInputType,
2807 TosaErrorValidator.evWrongOutputType,
2808 TosaErrorValidator.evWrongInputList,
2809 TosaErrorValidator.evWrongOutputList,
2810 TosaErrorValidator.evPadLargerEqualKernel,
2811 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002812 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002813 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002814 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002815 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002816 "transpose_conv2d_TEMPLATE": {
2817 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002818 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002819 "rank": (4, 4),
2820 "build_fcn": (
2821 build_transpose_conv2d,
2822 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002823 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002824 TosaArgGen.agTransposeConv2D,
2825 ),
2826 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002827 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002828 "invalid_test_validators": (
2829 TosaInvalidValidator.ivHeightWidthInvalid,
2830 TosaInvalidValidator.ivNonPositiveOutputShape,
2831 ),
2832 "error_if_validators": (
2833 TosaErrorValidator.evWrongInputType,
2834 TosaErrorValidator.evWrongOutputType,
2835 TosaErrorValidator.evWrongInputList,
2836 TosaErrorValidator.evWrongOutputList,
2837 TosaErrorValidator.evInputZeroPointNotZero,
2838 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07002839 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00002840 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002841 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002842 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002843 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002844 "template": True,
2845 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002846 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002847 "clamp": {
2848 "op": Op.CLAMP,
2849 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002850 "build_fcn": (
2851 build_clamp,
2852 TosaTensorGen.tgBasic,
2853 TosaTensorValuesGen.tvgDefault,
2854 None,
2855 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002856 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002857 "error_if_validators": (
2858 TosaErrorValidator.evMaxSmallerMin,
2859 TosaErrorValidator.evWrongInputType,
2860 TosaErrorValidator.evWrongOutputType,
2861 TosaErrorValidator.evWrongInputList,
2862 TosaErrorValidator.evWrongOutputList,
2863 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002864 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002865 "sigmoid": {
2866 "op": Op.SIGMOID,
2867 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002868 "build_fcn": (
2869 build_sigmoid,
2870 TosaTensorGen.tgBasic,
2871 TosaTensorValuesGen.tvgDefault,
2872 None,
2873 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002874 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002875 "error_if_validators": (
2876 TosaErrorValidator.evWrongInputType,
2877 TosaErrorValidator.evWrongOutputType,
2878 TosaErrorValidator.evWrongInputList,
2879 TosaErrorValidator.evWrongOutputList,
2880 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002881 },
2882 "tanh": {
2883 "op": Op.TANH,
2884 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002885 "build_fcn": (
2886 build_tanh,
2887 TosaTensorGen.tgBasic,
2888 TosaTensorValuesGen.tvgDefault,
2889 None,
2890 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002891 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002892 "error_if_validators": (
2893 TosaErrorValidator.evWrongInputType,
2894 TosaErrorValidator.evWrongOutputType,
2895 TosaErrorValidator.evWrongInputList,
2896 TosaErrorValidator.evWrongOutputList,
2897 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002898 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002899 # Elementwise Binary Operators
2900 "add": {
2901 "op": Op.ADD,
2902 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002903 "build_fcn": (
2904 build_binary_broadcast,
2905 TosaTensorGen.tgBroadcastFuzz,
2906 TosaTensorValuesGen.tvgAddSub,
2907 None,
2908 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002909 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002910 "error_if_validators": (
2911 TosaErrorValidator.evRankMismatch,
2912 TosaErrorValidator.evWrongInputType,
2913 TosaErrorValidator.evWrongOutputType,
2914 TosaErrorValidator.evWrongInputList,
2915 TosaErrorValidator.evWrongOutputList,
2916 TosaErrorValidator.evDimensionMismatch,
2917 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002918 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002919 "arithmetic_right_shift": {
2920 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2921 "operands": (2, 0),
2922 "build_fcn": (
2923 build_arithmetic_right_shift,
2924 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002925 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002926 TosaArgGen.agArithmeticRightShift,
2927 ),
2928 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002929 "error_if_validators": (
2930 TosaErrorValidator.evRankMismatch,
2931 TosaErrorValidator.evWrongInputType,
2932 TosaErrorValidator.evWrongOutputType,
2933 TosaErrorValidator.evWrongInputList,
2934 TosaErrorValidator.evWrongOutputList,
2935 TosaErrorValidator.evDimensionMismatch,
2936 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002937 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002938 "bitwise_and": {
2939 "op": Op.BITWISE_AND,
2940 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002941 "build_fcn": (
2942 build_binary_broadcast,
2943 TosaTensorGen.tgBroadcastFuzz,
2944 TosaTensorValuesGen.tvgDefault,
2945 None,
2946 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002947 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002948 "error_if_validators": (
2949 TosaErrorValidator.evRankMismatch,
2950 TosaErrorValidator.evWrongInputType,
2951 TosaErrorValidator.evWrongOutputType,
2952 TosaErrorValidator.evWrongInputList,
2953 TosaErrorValidator.evWrongOutputList,
2954 TosaErrorValidator.evDimensionMismatch,
2955 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002956 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002957 "bitwise_or": {
2958 "op": Op.BITWISE_OR,
2959 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002960 "build_fcn": (
2961 build_binary_broadcast,
2962 TosaTensorGen.tgBroadcastFuzz,
2963 TosaTensorValuesGen.tvgDefault,
2964 None,
2965 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002966 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002967 "error_if_validators": (
2968 TosaErrorValidator.evRankMismatch,
2969 TosaErrorValidator.evWrongInputType,
2970 TosaErrorValidator.evWrongOutputType,
2971 TosaErrorValidator.evWrongInputList,
2972 TosaErrorValidator.evWrongOutputList,
2973 TosaErrorValidator.evDimensionMismatch,
2974 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002975 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002976 "bitwise_xor": {
2977 "op": Op.BITWISE_XOR,
2978 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002979 "build_fcn": (
2980 build_binary_broadcast,
2981 TosaTensorGen.tgBroadcastFuzz,
2982 TosaTensorValuesGen.tvgDefault,
2983 None,
2984 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002985 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002986 "error_if_validators": (
2987 TosaErrorValidator.evRankMismatch,
2988 TosaErrorValidator.evWrongInputType,
2989 TosaErrorValidator.evWrongOutputType,
2990 TosaErrorValidator.evWrongInputList,
2991 TosaErrorValidator.evWrongOutputList,
2992 TosaErrorValidator.evDimensionMismatch,
2993 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002994 },
Matthew Haddon459443c2021-08-23 16:43:13 +01002995 "intdiv": {
2996 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002997 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002998 "build_fcn": (
2999 build_binary_broadcast,
3000 TosaTensorGen.tgBroadcastFuzz,
3001 TosaTensorValuesGen.tvgIntDiv,
3002 None,
3003 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003004 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003005 "error_if_validators": (
3006 TosaErrorValidator.evRankMismatch,
3007 TosaErrorValidator.evWrongInputType,
3008 TosaErrorValidator.evWrongOutputType,
3009 TosaErrorValidator.evWrongInputList,
3010 TosaErrorValidator.evWrongOutputList,
3011 TosaErrorValidator.evDimensionMismatch,
3012 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003013 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003014 "logical_and": {
3015 "op": Op.LOGICAL_AND,
3016 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003017 "build_fcn": (
3018 build_binary_broadcast,
3019 TosaTensorGen.tgBroadcastFuzz,
3020 TosaTensorValuesGen.tvgDefault,
3021 None,
3022 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003023 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003024 "error_if_validators": (
3025 TosaErrorValidator.evRankMismatch,
3026 TosaErrorValidator.evWrongInputType,
3027 TosaErrorValidator.evWrongOutputType,
3028 TosaErrorValidator.evWrongInputList,
3029 TosaErrorValidator.evWrongOutputList,
3030 TosaErrorValidator.evDimensionMismatch,
3031 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003032 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003033 "logical_left_shift": {
3034 "op": Op.LOGICAL_LEFT_SHIFT,
3035 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003036 "build_fcn": (
3037 build_binary_broadcast,
3038 TosaTensorGen.tgBroadcastFuzz,
3039 TosaTensorValuesGen.tvgLogicalShift,
3040 None,
3041 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003042 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003043 "error_if_validators": (
3044 TosaErrorValidator.evRankMismatch,
3045 TosaErrorValidator.evWrongInputType,
3046 TosaErrorValidator.evWrongOutputType,
3047 TosaErrorValidator.evWrongInputList,
3048 TosaErrorValidator.evWrongOutputList,
3049 TosaErrorValidator.evDimensionMismatch,
3050 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003051 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003052 "logical_right_shift": {
3053 "op": Op.LOGICAL_RIGHT_SHIFT,
3054 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003055 "build_fcn": (
3056 build_binary_broadcast,
3057 TosaTensorGen.tgBroadcastFuzz,
3058 TosaTensorValuesGen.tvgLogicalShift,
3059 None,
3060 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003061 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003062 "error_if_validators": (
3063 TosaErrorValidator.evRankMismatch,
3064 TosaErrorValidator.evWrongInputType,
3065 TosaErrorValidator.evWrongOutputType,
3066 TosaErrorValidator.evWrongInputList,
3067 TosaErrorValidator.evWrongOutputList,
3068 TosaErrorValidator.evDimensionMismatch,
3069 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003070 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003071 "logical_or": {
3072 "op": Op.LOGICAL_OR,
3073 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003074 "build_fcn": (
3075 build_binary_broadcast,
3076 TosaTensorGen.tgBroadcastFuzz,
3077 TosaTensorValuesGen.tvgDefault,
3078 None,
3079 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003080 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003081 "error_if_validators": (
3082 TosaErrorValidator.evRankMismatch,
3083 TosaErrorValidator.evWrongInputType,
3084 TosaErrorValidator.evWrongOutputType,
3085 TosaErrorValidator.evWrongInputList,
3086 TosaErrorValidator.evWrongOutputList,
3087 TosaErrorValidator.evDimensionMismatch,
3088 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003089 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003090 "logical_xor": {
3091 "op": Op.LOGICAL_XOR,
3092 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003093 "build_fcn": (
3094 build_binary_broadcast,
3095 TosaTensorGen.tgBroadcastFuzz,
3096 TosaTensorValuesGen.tvgDefault,
3097 None,
3098 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003099 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003100 "error_if_validators": (
3101 TosaErrorValidator.evRankMismatch,
3102 TosaErrorValidator.evWrongInputType,
3103 TosaErrorValidator.evWrongOutputType,
3104 TosaErrorValidator.evWrongInputList,
3105 TosaErrorValidator.evWrongOutputList,
3106 TosaErrorValidator.evDimensionMismatch,
3107 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003108 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003109 "maximum": {
3110 "op": Op.MAXIMUM,
3111 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003112 "build_fcn": (
3113 build_binary_broadcast,
3114 TosaTensorGen.tgBroadcastFuzz,
3115 TosaTensorValuesGen.tvgDefault,
3116 None,
3117 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003118 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003119 "error_if_validators": (
3120 TosaErrorValidator.evRankMismatch,
3121 TosaErrorValidator.evWrongInputType,
3122 TosaErrorValidator.evWrongOutputType,
3123 TosaErrorValidator.evWrongInputList,
3124 TosaErrorValidator.evWrongOutputList,
3125 TosaErrorValidator.evDimensionMismatch,
3126 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003127 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003128 "minimum": {
3129 "op": Op.MINIMUM,
3130 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003131 "build_fcn": (
3132 build_binary_broadcast,
3133 TosaTensorGen.tgBroadcastFuzz,
3134 TosaTensorValuesGen.tvgDefault,
3135 None,
3136 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003137 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003138 "error_if_validators": (
3139 TosaErrorValidator.evRankMismatch,
3140 TosaErrorValidator.evWrongInputType,
3141 TosaErrorValidator.evWrongOutputType,
3142 TosaErrorValidator.evWrongInputList,
3143 TosaErrorValidator.evWrongOutputList,
3144 TosaErrorValidator.evDimensionMismatch,
3145 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003146 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003147 "mul": {
3148 "op": Op.MUL,
3149 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003150 "build_fcn": (
3151 build_mul,
3152 TosaTensorGen.tgBroadcastFuzz,
3153 TosaTensorValuesGen.tvgMul,
3154 TosaArgGen.agMul,
3155 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003156 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003157 "error_if_validators": (
3158 TosaErrorValidator.evWrongInputType,
3159 TosaErrorValidator.evWrongOutputType,
3160 TosaErrorValidator.evWrongInputList,
3161 TosaErrorValidator.evWrongOutputList,
3162 TosaErrorValidator.evRankMismatch,
3163 TosaErrorValidator.evDimensionMismatch,
3164 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003165 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003166 "pow": {
3167 "op": Op.POW,
3168 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003169 "build_fcn": (
3170 build_binary_broadcast,
3171 TosaTensorGen.tgBroadcastFuzz,
3172 TosaTensorValuesGen.tvgDefault,
3173 None,
3174 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003175 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003176 "error_if_validators": (
3177 TosaErrorValidator.evRankMismatch,
3178 TosaErrorValidator.evWrongInputType,
3179 TosaErrorValidator.evWrongOutputType,
3180 TosaErrorValidator.evWrongInputList,
3181 TosaErrorValidator.evWrongOutputList,
3182 TosaErrorValidator.evDimensionMismatch,
3183 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003184 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003185 "sub": {
3186 "op": Op.SUB,
3187 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003188 "build_fcn": (
3189 build_binary_broadcast,
3190 TosaTensorGen.tgBroadcastFuzz,
3191 TosaTensorValuesGen.tvgAddSub,
3192 None,
3193 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003194 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003195 "error_if_validators": (
3196 TosaErrorValidator.evRankMismatch,
3197 TosaErrorValidator.evWrongInputType,
3198 TosaErrorValidator.evWrongOutputType,
3199 TosaErrorValidator.evWrongInputList,
3200 TosaErrorValidator.evWrongOutputList,
3201 TosaErrorValidator.evDimensionMismatch,
3202 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003203 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003204 "table": {
3205 "op": Op.TABLE,
3206 # Use the automatic generation functions to create the input array
3207 # but create the table tensor in the build function, as it may be
3208 # a different type from the input
3209 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003210 "build_fcn": (
3211 build_table,
3212 TosaTensorGen.tgBasic,
3213 TosaTensorValuesGen.tvgDefault,
3214 TosaArgGen.agTable,
3215 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003216 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003217 "error_if_validators": (
3218 TosaErrorValidator.evWrongInputType,
3219 TosaErrorValidator.evWrongOutputType,
3220 TosaErrorValidator.evWrongInputList,
3221 TosaErrorValidator.evWrongOutputList,
3222 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003223 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003224 # Elementwise Unary operators
3225 "abs": {
3226 "op": Op.ABS,
3227 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003228 "build_fcn": (
3229 build_unary,
3230 TosaTensorGen.tgBasic,
3231 TosaTensorValuesGen.tvgDefault,
3232 None,
3233 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003234 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003235 "error_if_validators": (
3236 TosaErrorValidator.evWrongInputType,
3237 TosaErrorValidator.evWrongOutputType,
3238 TosaErrorValidator.evWrongInputList,
3239 TosaErrorValidator.evWrongOutputList,
3240 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003241 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003242 "bitwise_not": {
3243 "op": Op.BITWISE_NOT,
3244 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003245 "build_fcn": (
3246 build_unary,
3247 TosaTensorGen.tgBasic,
3248 TosaTensorValuesGen.tvgDefault,
3249 None,
3250 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003251 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003252 "error_if_validators": (
3253 TosaErrorValidator.evWrongInputType,
3254 TosaErrorValidator.evWrongOutputType,
3255 TosaErrorValidator.evWrongInputList,
3256 TosaErrorValidator.evWrongOutputList,
3257 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003258 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003259 "ceil": {
3260 "op": Op.CEIL,
3261 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003262 "build_fcn": (
3263 build_unary,
3264 TosaTensorGen.tgBasic,
3265 TosaTensorValuesGen.tvgDefault,
3266 None,
3267 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003268 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003269 "error_if_validators": (
3270 TosaErrorValidator.evWrongInputType,
3271 TosaErrorValidator.evWrongOutputType,
3272 TosaErrorValidator.evWrongInputList,
3273 TosaErrorValidator.evWrongOutputList,
3274 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003275 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003276 "clz": {
3277 "op": Op.CLZ,
3278 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003279 "build_fcn": (
3280 build_unary,
3281 TosaTensorGen.tgBasic,
3282 TosaTensorValuesGen.tvgDefault,
3283 None,
3284 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003285 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003286 "error_if_validators": (
3287 TosaErrorValidator.evWrongInputType,
3288 TosaErrorValidator.evWrongOutputType,
3289 TosaErrorValidator.evWrongInputList,
3290 TosaErrorValidator.evWrongOutputList,
3291 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003292 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003293 "exp": {
3294 "op": Op.EXP,
3295 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003296 "build_fcn": (
3297 build_unary,
3298 TosaTensorGen.tgBasic,
3299 TosaTensorValuesGen.tvgDefault,
3300 None,
3301 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003302 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003303 "error_if_validators": (
3304 TosaErrorValidator.evWrongInputType,
3305 TosaErrorValidator.evWrongOutputType,
3306 TosaErrorValidator.evWrongInputList,
3307 TosaErrorValidator.evWrongOutputList,
3308 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003309 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003310 "floor": {
3311 "op": Op.FLOOR,
3312 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003313 "build_fcn": (
3314 build_unary,
3315 TosaTensorGen.tgBasic,
3316 TosaTensorValuesGen.tvgDefault,
3317 None,
3318 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003319 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003320 "error_if_validators": (
3321 TosaErrorValidator.evWrongInputType,
3322 TosaErrorValidator.evWrongOutputType,
3323 TosaErrorValidator.evWrongInputList,
3324 TosaErrorValidator.evWrongOutputList,
3325 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003326 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003327 "log": {
3328 "op": Op.LOG,
3329 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003330 "build_fcn": (
3331 build_unary,
3332 TosaTensorGen.tgBasic,
3333 TosaTensorValuesGen.tvgDefault,
3334 None,
3335 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003336 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003337 "error_if_validators": (
3338 TosaErrorValidator.evWrongInputType,
3339 TosaErrorValidator.evWrongOutputType,
3340 TosaErrorValidator.evWrongInputList,
3341 TosaErrorValidator.evWrongOutputList,
3342 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003343 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003344 "logical_not": {
3345 "op": Op.LOGICAL_NOT,
3346 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003347 "build_fcn": (
3348 build_unary,
3349 TosaTensorGen.tgBasic,
3350 TosaTensorValuesGen.tvgDefault,
3351 None,
3352 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003353 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003354 "error_if_validators": (
3355 TosaErrorValidator.evWrongInputType,
3356 TosaErrorValidator.evWrongOutputType,
3357 TosaErrorValidator.evWrongInputList,
3358 TosaErrorValidator.evWrongOutputList,
3359 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003360 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003361 "negate": {
3362 "op": Op.NEGATE,
3363 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003364 "build_fcn": (
3365 build_unary,
3366 TosaTensorGen.tgBasic,
3367 TosaTensorValuesGen.tvgNegate,
3368 None,
3369 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003370 "qgen": TosaQuantGen.qgUnary,
3371 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003372 "error_if_validators": (
3373 TosaErrorValidator.evInputZeroPointNotZero,
3374 TosaErrorValidator.evOutputZeroPointNotZero,
3375 TosaErrorValidator.evWrongInputType,
3376 TosaErrorValidator.evWrongOutputType,
3377 TosaErrorValidator.evWrongInputList,
3378 TosaErrorValidator.evWrongOutputList,
3379 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003380 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003381 "reciprocal": {
3382 "op": Op.RECIPROCAL,
3383 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003384 "build_fcn": (
3385 build_unary,
3386 TosaTensorGen.tgBasic,
3387 TosaTensorValuesGen.tvgDefault,
3388 None,
3389 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003390 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003391 "error_if_validators": (
3392 TosaErrorValidator.evWrongInputType,
3393 TosaErrorValidator.evWrongOutputType,
3394 TosaErrorValidator.evWrongInputList,
3395 TosaErrorValidator.evWrongOutputList,
3396 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003397 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003398 "rsqrt": {
3399 "op": Op.RSQRT,
3400 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003401 "build_fcn": (
3402 build_unary,
3403 TosaTensorGen.tgBasic,
3404 TosaTensorValuesGen.tvgDefault,
3405 None,
3406 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003407 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003408 "error_if_validators": (
3409 TosaErrorValidator.evWrongInputType,
3410 TosaErrorValidator.evWrongOutputType,
3411 TosaErrorValidator.evWrongInputList,
3412 TosaErrorValidator.evWrongOutputList,
3413 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003414 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003415 # Elementwise Ternary operators
3416 "select": {
3417 "op": Op.SELECT,
3418 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003419 "build_fcn": (
3420 build_select,
3421 TosaTensorGen.tgBroadcastFuzz,
3422 TosaTensorValuesGen.tvgSelect,
3423 None,
3424 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003425 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003426 "error_if_validators": (
3427 TosaErrorValidator.evRankMismatch,
3428 TosaErrorValidator.evWrongInputType,
3429 TosaErrorValidator.evWrongOutputType,
3430 TosaErrorValidator.evWrongInputList,
3431 TosaErrorValidator.evWrongOutputList,
3432 TosaErrorValidator.evDimensionMismatch,
3433 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003434 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003435 # Comparison operators
3436 "equal": {
3437 "op": Op.EQUAL,
3438 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003439 "build_fcn": (
3440 build_comparison,
3441 TosaTensorGen.tgBroadcastFuzz,
3442 TosaTensorValuesGen.tvgEqual,
3443 None,
3444 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003445 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003446 "error_if_validators": (
3447 TosaErrorValidator.evRankMismatch,
3448 TosaErrorValidator.evWrongInputType,
3449 TosaErrorValidator.evWrongOutputType,
3450 TosaErrorValidator.evWrongInputList,
3451 TosaErrorValidator.evWrongOutputList,
3452 TosaErrorValidator.evDimensionMismatch,
3453 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003454 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003455 "greater_equal": {
3456 "op": Op.GREATER_EQUAL,
3457 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003458 "build_fcn": (
3459 build_comparison,
3460 TosaTensorGen.tgBroadcastFuzz,
3461 TosaTensorValuesGen.tvgDefault,
3462 None,
3463 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003464 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003465 "error_if_validators": (
3466 TosaErrorValidator.evRankMismatch,
3467 TosaErrorValidator.evWrongInputType,
3468 TosaErrorValidator.evWrongOutputType,
3469 TosaErrorValidator.evWrongInputList,
3470 TosaErrorValidator.evWrongOutputList,
3471 TosaErrorValidator.evDimensionMismatch,
3472 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003473 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003474 "greater": {
3475 "op": Op.GREATER,
3476 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003477 "build_fcn": (
3478 build_comparison,
3479 TosaTensorGen.tgBroadcastFuzz,
3480 TosaTensorValuesGen.tvgDefault,
3481 None,
3482 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003483 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003484 "error_if_validators": (
3485 TosaErrorValidator.evRankMismatch,
3486 TosaErrorValidator.evWrongInputType,
3487 TosaErrorValidator.evWrongOutputType,
3488 TosaErrorValidator.evWrongInputList,
3489 TosaErrorValidator.evWrongOutputList,
3490 TosaErrorValidator.evDimensionMismatch,
3491 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003492 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003493 # Reduction operators
3494 "reduce_all": {
3495 "op": Op.REDUCE_ALL,
3496 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003497 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003498 "build_fcn": (
3499 build_reduce,
3500 TosaTensorGen.tgBasic,
3501 TosaTensorValuesGen.tvgDefault,
3502 TosaArgGen.agAxis,
3503 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003504 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003505 "error_if_validators": (
3506 TosaErrorValidator.evAxisLargerRank,
3507 TosaErrorValidator.evAxisSmallerZero,
3508 TosaErrorValidator.evShapeOfAxisNotOne,
3509 TosaErrorValidator.evWrongInputType,
3510 TosaErrorValidator.evWrongOutputType,
3511 TosaErrorValidator.evWrongRank,
3512 TosaErrorValidator.evWrongInputList,
3513 TosaErrorValidator.evWrongOutputList,
3514 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003515 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003516 "reduce_any": {
3517 "op": Op.REDUCE_ANY,
3518 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003519 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003520 "build_fcn": (
3521 build_reduce,
3522 TosaTensorGen.tgBasic,
3523 TosaTensorValuesGen.tvgDefault,
3524 TosaArgGen.agAxis,
3525 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003526 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003527 "error_if_validators": (
3528 TosaErrorValidator.evAxisLargerRank,
3529 TosaErrorValidator.evAxisSmallerZero,
3530 TosaErrorValidator.evShapeOfAxisNotOne,
3531 TosaErrorValidator.evWrongInputType,
3532 TosaErrorValidator.evWrongOutputType,
3533 TosaErrorValidator.evWrongRank,
3534 TosaErrorValidator.evWrongInputList,
3535 TosaErrorValidator.evWrongOutputList,
3536 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003537 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003538 "reduce_max": {
3539 "op": Op.REDUCE_MAX,
3540 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003541 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003542 "build_fcn": (
3543 build_reduce,
3544 TosaTensorGen.tgBasic,
3545 TosaTensorValuesGen.tvgDefault,
3546 TosaArgGen.agAxis,
3547 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003548 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003549 "error_if_validators": (
3550 TosaErrorValidator.evAxisLargerRank,
3551 TosaErrorValidator.evAxisSmallerZero,
3552 TosaErrorValidator.evShapeOfAxisNotOne,
3553 TosaErrorValidator.evWrongInputType,
3554 TosaErrorValidator.evWrongOutputType,
3555 TosaErrorValidator.evWrongRank,
3556 TosaErrorValidator.evWrongInputList,
3557 TosaErrorValidator.evWrongOutputList,
3558 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003559 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003560 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003561 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003562 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003563 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003564 "build_fcn": (
3565 build_reduce,
3566 TosaTensorGen.tgBasic,
3567 TosaTensorValuesGen.tvgDefault,
3568 TosaArgGen.agAxis,
3569 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003570 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003571 "error_if_validators": (
3572 TosaErrorValidator.evAxisLargerRank,
3573 TosaErrorValidator.evAxisSmallerZero,
3574 TosaErrorValidator.evShapeOfAxisNotOne,
3575 TosaErrorValidator.evWrongInputType,
3576 TosaErrorValidator.evWrongOutputType,
3577 TosaErrorValidator.evWrongRank,
3578 TosaErrorValidator.evWrongInputList,
3579 TosaErrorValidator.evWrongOutputList,
3580 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003581 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003582 "reduce_product": {
3583 "op": Op.REDUCE_PRODUCT,
3584 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003585 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003586 "build_fcn": (
3587 build_reduce,
3588 TosaTensorGen.tgBasic,
3589 TosaTensorValuesGen.tvgDefault,
3590 TosaArgGen.agAxis,
3591 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003592 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003593 "error_if_validators": (
3594 TosaErrorValidator.evAxisLargerRank,
3595 TosaErrorValidator.evAxisSmallerZero,
3596 TosaErrorValidator.evShapeOfAxisNotOne,
3597 TosaErrorValidator.evWrongInputType,
3598 TosaErrorValidator.evWrongOutputType,
3599 TosaErrorValidator.evWrongRank,
3600 TosaErrorValidator.evWrongInputList,
3601 TosaErrorValidator.evWrongOutputList,
3602 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003603 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003604 "reduce_sum": {
3605 "op": Op.REDUCE_SUM,
3606 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003607 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003608 "build_fcn": (
3609 build_reduce,
3610 TosaTensorGen.tgBasic,
3611 TosaTensorValuesGen.tvgReduceSum,
3612 TosaArgGen.agAxis,
3613 ),
James Ward24dbc422022-10-19 12:20:31 +01003614 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003615 "error_if_validators": (
3616 TosaErrorValidator.evAxisLargerRank,
3617 TosaErrorValidator.evAxisSmallerZero,
3618 TosaErrorValidator.evShapeOfAxisNotOne,
3619 TosaErrorValidator.evWrongInputType,
3620 TosaErrorValidator.evWrongOutputType,
3621 TosaErrorValidator.evWrongRank,
3622 TosaErrorValidator.evWrongInputList,
3623 TosaErrorValidator.evWrongOutputList,
3624 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003625 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003626 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003627 "concat": {
3628 "op": Op.CONCAT,
3629 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003630 "build_fcn": (
3631 build_concat,
3632 TosaTensorGen.tgConcat,
3633 TosaTensorValuesGen.tvgConcat,
3634 TosaArgGen.agAxis,
3635 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003636 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003637 "error_if_validators": (
3638 TosaErrorValidator.evAxisLargerRank,
3639 TosaErrorValidator.evAxisSmallerZero,
3640 TosaErrorValidator.evConcatInputRankMismatch,
3641 TosaErrorValidator.evConcatShapeSumMismatch,
3642 TosaErrorValidator.evConcatInputDimMismatch,
3643 TosaErrorValidator.evWrongInputType,
3644 TosaErrorValidator.evWrongOutputType,
3645 TosaErrorValidator.evWrongOutputList,
3646 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003647 },
3648 "pad": {
3649 "op": Op.PAD,
3650 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003651 "build_fcn": (
3652 build_pad,
3653 TosaTensorGen.tgBasic,
3654 TosaTensorValuesGen.tvgDefault,
3655 TosaArgGen.agPad,
3656 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003657 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003658 "error_if_validators": (
3659 TosaErrorValidator.evWrongInputType,
3660 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003661 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003662 TosaErrorValidator.evWrongOutputType,
3663 TosaErrorValidator.evWrongInputList,
3664 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003665 TosaErrorValidator.evRankMismatch,
3666 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003667 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003668 },
3669 "reshape": {
3670 "op": Op.RESHAPE,
3671 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003672 "build_fcn": (
3673 build_reshape,
3674 TosaTensorGen.tgBasic,
3675 TosaTensorValuesGen.tvgDefault,
3676 TosaArgGen.agReshape,
3677 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003678 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003679 "error_if_validators": (
3680 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3681 TosaErrorValidator.evWrongInputType,
3682 TosaErrorValidator.evWrongOutputType,
3683 TosaErrorValidator.evWrongInputList,
3684 TosaErrorValidator.evWrongOutputList,
3685 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003686 },
3687 "reverse": {
3688 "op": Op.REVERSE,
3689 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003690 "build_fcn": (
3691 build_reverse,
3692 TosaTensorGen.tgBasic,
3693 TosaTensorValuesGen.tvgDefault,
3694 TosaArgGen.agAxis,
3695 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003696 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003697 "error_if_validators": (
3698 TosaErrorValidator.evAxisSmallerZero,
3699 TosaErrorValidator.evAxisLargerRank,
3700 TosaErrorValidator.evWrongInputType,
3701 TosaErrorValidator.evWrongOutputType,
3702 TosaErrorValidator.evWrongInputList,
3703 TosaErrorValidator.evWrongOutputList,
3704 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003705 },
3706 "slice": {
3707 "op": Op.SLICE,
3708 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003709 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003710 "build_fcn": (
3711 build_slice,
3712 TosaTensorGen.tgBasic,
3713 TosaTensorValuesGen.tvgDefault,
3714 TosaArgGen.agSlice,
3715 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003716 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003717 "error_if_validators": (
3718 TosaErrorValidator.evStartSmallerZero,
3719 TosaErrorValidator.evSizeSmallerEqualZero,
3720 TosaErrorValidator.evStartSizeOutsideBounds,
3721 TosaErrorValidator.evSizeOutputShapeMismatch,
3722 TosaErrorValidator.evInputSizeStartLengthMismatch,
3723 TosaErrorValidator.evWrongRank,
3724 TosaErrorValidator.evWrongInputType,
3725 TosaErrorValidator.evWrongOutputType,
3726 TosaErrorValidator.evWrongInputList,
3727 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003728 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003729 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003730 },
3731 "tile": {
3732 "op": Op.TILE,
3733 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003734 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003735 "build_fcn": (
3736 build_tile,
3737 TosaTensorGen.tgBasic,
3738 TosaTensorValuesGen.tvgDefault,
3739 TosaArgGen.agTile,
3740 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003741 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003742 "error_if_validators": (
3743 TosaErrorValidator.evWrongInputType,
3744 TosaErrorValidator.evWrongOutputType,
3745 TosaErrorValidator.evWrongInputList,
3746 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003747 TosaErrorValidator.evRankMismatch,
3748 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003749 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003750 },
3751 "transpose": {
3752 "op": Op.TRANSPOSE,
3753 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003754 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003755 "build_fcn": (
3756 build_transpose,
3757 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003758 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003759 TosaArgGen.agTranspose,
3760 ),
3761 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003762 "error_if_validators": (
3763 TosaErrorValidator.evIndexOutsideBounds,
3764 TosaErrorValidator.evIndexUsedTwice,
3765 TosaErrorValidator.evWrongInputType,
3766 TosaErrorValidator.evWrongOutputType,
3767 TosaErrorValidator.evWrongInputList,
3768 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003769 TosaErrorValidator.evWrongRank,
3770 TosaErrorValidator.evRankMismatch,
3771 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003772 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003773 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003774 # Data nodes
3775 "const": {
3776 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003777 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003778 "build_fcn": (
3779 build_const,
3780 TosaTensorGen.tgBasic,
3781 TosaTensorValuesGen.tvgDefault,
3782 None,
3783 ),
Luke Hutton65872422023-02-20 10:33:04 +00003784 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08003785 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003786 "identity": {
3787 "op": Op.IDENTITY,
3788 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003789 "build_fcn": (
3790 build_unary,
3791 TosaTensorGen.tgBasic,
3792 TosaTensorValuesGen.tvgDefault,
3793 None,
3794 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003795 "types": TYPE_FIB,
3796 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003797 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003798 "gather": {
3799 "op": Op.GATHER,
3800 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3801 "operands": (1, 0),
3802 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003803 "build_fcn": (
3804 build_gather,
3805 TosaTensorGen.tgBasic,
3806 TosaTensorValuesGen.tvgDefault,
3807 None,
3808 ),
James Ward24dbc422022-10-19 12:20:31 +01003809 "types": (
3810 DType.INT8,
3811 DType.INT16,
3812 DType.INT32,
3813 DType.FP16,
3814 DType.BF16,
3815 DType.FP32,
3816 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003817 "error_if_validators": (
3818 TosaErrorValidator.evWrongInputType,
3819 TosaErrorValidator.evWrongOutputType,
3820 TosaErrorValidator.evWrongInputList,
3821 TosaErrorValidator.evWrongOutputList,
3822 TosaErrorValidator.evWrongRank,
3823 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003824 },
3825 "scatter": {
3826 "op": Op.SCATTER,
3827 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003828 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003829 "operands": (2, 0),
3830 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003831 "build_fcn": (
3832 build_scatter,
3833 TosaTensorGen.tgScatter,
3834 TosaTensorValuesGen.tvgDefault,
3835 None,
3836 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003837 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003838 "error_if_validators": (
3839 TosaErrorValidator.evWrongInputType,
3840 TosaErrorValidator.evWrongOutputType,
3841 TosaErrorValidator.evWrongInputList,
3842 TosaErrorValidator.evWrongOutputList,
3843 TosaErrorValidator.evWrongRank,
3844 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003845 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003846 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003847 "resize": {
3848 "op": Op.RESIZE,
3849 "operands": (1, 0),
3850 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003851 "build_fcn": (
3852 build_resize,
3853 TosaTensorGen.tgNHWC,
3854 TosaTensorValuesGen.tvgDefault,
3855 TosaArgGen.agResize,
3856 ),
James Ward24dbc422022-10-19 12:20:31 +01003857 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003858 "invalid_test_validators": (
3859 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003860 ),
3861 "error_if_validators": (
3862 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003863 TosaErrorValidator.evScaleSmallerEqualZero,
3864 TosaErrorValidator.evScaleNLargerMax,
3865 TosaErrorValidator.evScaleDLargerMax,
3866 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003867 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003868 TosaErrorValidator.evBorderSmallerMin,
3869 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003870 TosaErrorValidator.evWrongInputType,
3871 TosaErrorValidator.evWrongOutputType,
3872 TosaErrorValidator.evWrongRank,
3873 TosaErrorValidator.evWrongInputList,
3874 TosaErrorValidator.evWrongOutputList,
3875 TosaErrorValidator.evBatchMismatch,
3876 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003877 TosaErrorValidator.evResizeOutputShapeMismatch,
3878 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003879 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003880 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003881 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003882 "cast": {
3883 "op": Op.CAST,
3884 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003885 "build_fcn": (
3886 build_cast,
3887 TosaTensorGen.tgBasic,
3888 TosaTensorValuesGen.tvgDefault,
3889 TosaArgGen.agCast,
3890 ),
James Ward8b390432022-08-12 20:48:56 +01003891 "types": (
3892 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01003893 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003894 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01003895 DType.INT8,
3896 DType.INT16,
3897 DType.INT32,
3898 DType.BOOL,
3899 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003900 "error_if_validators": (
3901 TosaErrorValidator.evWrongInputType,
3902 TosaErrorValidator.evWrongOutputType,
3903 TosaErrorValidator.evWrongInputList,
3904 TosaErrorValidator.evWrongOutputList,
3905 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003906 },
3907 "rescale": {
3908 "op": Op.RESCALE,
3909 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003910 "build_fcn": (
3911 build_rescale,
3912 TosaTensorGen.tgBasic,
3913 TosaTensorValuesGen.tvgDefault,
3914 TosaArgGen.agRescale,
3915 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003916 "types": [
3917 DType.UINT8,
3918 DType.INT8,
3919 DType.INT16,
3920 DType.INT32,
3921 DType.INT48,
3922 DType.UINT16,
3923 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003924 "error_if_validators": (
3925 TosaErrorValidator.evInputZeroPointNotZero,
3926 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003927 TosaErrorValidator.evU16InputZeroPointNotValid,
3928 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003929 TosaErrorValidator.evScaleTrue,
3930 TosaErrorValidator.evScaleNotTrue,
3931 TosaErrorValidator.evWrongInputType,
3932 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003933 TosaErrorValidator.evWrongInputList,
3934 TosaErrorValidator.evWrongOutputList,
3935 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003936 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003937 # Custom
3938 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003939 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07003940 # Two variants of cond_if: one that generates one of two constant tensors (no
3941 # inputs to the basic blocks, one output), and another that either adds or subtracts two tensors
3942 # (two inputs to the basic blocks, one output).
Kevin Cheng550ccc52021-03-03 11:21:43 -08003943 "cond_if_const": {
3944 "op": Op.COND_IF,
3945 "operands": (0, 2),
3946 "build_fcn": (
3947 build_cond_if_const,
3948 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003949 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003950 TosaArgGen.agCondIf,
3951 ),
3952 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003953 "error_if_validators": (
3954 TosaErrorValidator.evOutputListThenGraphMismatch,
3955 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003956 TosaErrorValidator.evCondIfCondNotMatchingBool,
3957 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003958 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003959 },
3960 "cond_if_binary": {
3961 "op": Op.COND_IF,
3962 "operands": (2, 0),
3963 "build_fcn": (
3964 build_cond_if_binary,
3965 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003966 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003967 TosaArgGen.agCondIf,
3968 ),
Les Bell6040b4d2021-10-11 12:50:31 +01003969 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003970 "error_if_validators": (
3971 TosaErrorValidator.evInputListThenGraphMismatch,
3972 TosaErrorValidator.evInputListElseGraphMismatch,
3973 TosaErrorValidator.evOutputListThenGraphMismatch,
3974 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003975 TosaErrorValidator.evCondIfCondNotMatchingBool,
3976 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003977 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003978 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003979 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003980 "while_loop": {
3981 "op": Op.WHILE_LOOP,
3982 "operands": (0, 1),
3983 "build_fcn": (
3984 build_while_loop,
3985 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003986 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003987 TosaArgGen.agWhileLoop,
3988 ),
3989 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003990 "error_if_validators": (
3991 TosaErrorValidator.evInputListOutputListMismatch,
3992 TosaErrorValidator.evInputListCondGraphMismatch,
3993 TosaErrorValidator.evInputListBodyGraphInputMismatch,
3994 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
3995 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003996 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003997 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003998 },
Luke Hutton57287132023-02-06 14:54:18 +00003999 "fft2d": {
4000 "op": Op.FFT2D,
4001 "operands": (2, 0),
4002 "rank": (3, 3),
4003 "build_fcn": (
4004 build_fft2d,
4005 TosaTensorGen.tgFFT2d,
4006 TosaTensorValuesGen.tvgDefault,
4007 TosaArgGen.agFFT2d,
4008 ),
4009 "types": [DType.FP32],
4010 "error_if_validators": (
4011 TosaErrorValidator.evWrongInputType,
4012 TosaErrorValidator.evWrongOutputType,
4013 TosaErrorValidator.evWrongInputList,
4014 TosaErrorValidator.evWrongOutputList,
4015 TosaErrorValidator.evWrongRank,
4016 TosaErrorValidator.evBatchMismatch,
4017 TosaErrorValidator.evKernelNotPowerOfTwo,
4018 TosaErrorValidator.evFFTInputShapeMismatch,
4019 TosaErrorValidator.evFFTOutputShapeMismatch,
4020 ),
4021 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004022 "rfft2d": {
4023 "op": Op.RFFT2D,
4024 "operands": (1, 0),
4025 "rank": (3, 3),
4026 "build_fcn": (
4027 build_rfft2d,
4028 TosaTensorGen.tgRFFT2d,
4029 TosaTensorValuesGen.tvgDefault,
4030 TosaArgGen.agNone,
4031 ),
4032 "types": [DType.FP32],
4033 "error_if_validators": (
4034 TosaErrorValidator.evWrongInputType,
4035 TosaErrorValidator.evWrongOutputType,
4036 TosaErrorValidator.evWrongInputList,
4037 TosaErrorValidator.evWrongOutputList,
4038 TosaErrorValidator.evWrongRank,
4039 TosaErrorValidator.evBatchMismatch,
4040 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004041 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004042 ),
4043 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004044 }
4045
Kevin Cheng550ccc52021-03-03 11:21:43 -08004046
Eric Kunzee5e26762020-10-13 16:11:07 -07004047class OutputShaper:
4048 # Methods in this class compute the expected output shape and datatype
4049 # for common classes of operations
4050 def __init__(self):
4051 pass
4052
4053 # These methods return arguments that can be used for
4054 # creating a new output tensor
4055 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004056 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
4057 if error_name != ErrorIf.RankMismatch:
4058 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004059 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004060
4061 shape = []
4062 for i in range(len(a.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004063 if a.shape[i] == 1 and error_name is None:
Eric Kunzee5e26762020-10-13 16:11:07 -07004064 shape.append(b.shape[i])
4065 else:
4066 shape.append(a.shape[i])
4067
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004068 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004069 all_dtypes = [
4070 DType.INT8,
4071 DType.INT16,
4072 DType.INT32,
4073 DType.INT48,
James Ward24dbc422022-10-19 12:20:31 +01004074 DType.FP16,
4075 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004076 DType.FP32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004077 ]
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004078 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4079 outputDType = rng.choice(wrong_dtypes)
4080 else:
4081 outputDType = a.dtype
4082
4083 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004084
4085 @staticmethod
4086 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004087 assert len(a.shape) == len(b.shape)
4088 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004089
4090 shape = []
4091 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004092 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07004093 shape.append(a.shape[i])
4094
Kevin Cheng550ccc52021-03-03 11:21:43 -08004095 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004096
4097 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004098 def unaryOp(ser, rng, a, error_name=None):
4099 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004100 all_dtypes = [
4101 DType.INT8,
4102 DType.INT16,
4103 DType.INT32,
4104 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004105 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004106 DType.FP16,
4107 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004108 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004109 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4110 outputDType = rng.choice(wrong_dtypes)
4111 else:
4112 outputDType = a.dtype
4113
4114 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004115
4116 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004117 def selectOp(ser, rng, cond, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004118 if error_name != ErrorIf.RankMismatch:
4119 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004120 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004121
4122 shape = []
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004123 for i in range(len(cond.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004124 if cond.shape[i] == 1 and error_name is None:
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004125 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
4126 else:
4127 shape.append(cond.shape[i])
Eric Kunzee5e26762020-10-13 16:11:07 -07004128
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004129 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004130 all_dtypes = [
4131 DType.INT8,
4132 DType.INT16,
4133 DType.INT32,
4134 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004135 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004136 DType.FP16,
4137 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004138 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004139 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4140 outputDType = rng.choice(wrong_dtypes)
4141 else:
4142 outputDType = a.dtype
4143
4144 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004145
4146 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004147 def binaryComparisonOp(ser, rng, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004148 if error_name != ErrorIf.RankMismatch:
4149 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004150 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004151
4152 # Do broadcast
4153 shape = []
4154 for i in range(len(a.shape)):
Eric Kunzea1d49852022-01-04 10:07:29 -08004155 if a.shape[i] == 1 and len(b.shape) > i:
Eric Kunzee5e26762020-10-13 16:11:07 -07004156 shape.append(b.shape[i])
4157 else:
4158 shape.append(a.shape[i])
4159
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004160 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004161 wrong_dtypes = [
4162 DType.INT8,
4163 DType.INT16,
4164 DType.INT32,
4165 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004166 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004167 DType.FP16,
4168 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004169 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004170 outputDType = rng.choice(wrong_dtypes)
4171 else:
4172 outputDType = DType.BOOL
4173
4174 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004175
4176 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01004177 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004178 shape = a.shape.copy()
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004179 if error_name not in [
4180 ErrorIf.AxisSmallerZero,
4181 ErrorIf.AxisLargerRank,
4182 ErrorIf.ShapeOfAxisNotOne,
4183 ]:
Matthew Haddond6ce7252021-09-29 15:35:44 +01004184 shape[axis] = 1
4185 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
4186 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07004187
Matthew Haddond6ce7252021-09-29 15:35:44 +01004188 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004189 all_dtypes = [
4190 DType.INT8,
4191 DType.INT16,
4192 DType.INT32,
4193 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004194 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004195 DType.FP16,
4196 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004197 ]
Matthew Haddond6ce7252021-09-29 15:35:44 +01004198 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4199 outputDType = rng.choice(wrong_dtypes)
4200 else:
4201 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004202
Matthew Haddond6ce7252021-09-29 15:35:44 +01004203 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004204
4205 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004206 def argmaxOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004207 shape = a.shape.copy()
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004208
4209 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
4210 del shape[axis]
4211
4212 if error_name == ErrorIf.ArgmaxOutputRankMismatch:
4213 remove = rng.choice([True, False])
4214 if remove and len(shape) > 1:
4215 del shape[0]
4216 else:
4217 shape.append(1)
4218 elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
4219 for i in range(len(shape)):
4220 shape[i] = shape[i] + rng.integers(1, 10)
4221
4222 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004223 all_dtypes = [
4224 DType.INT8,
4225 DType.INT16,
4226 DType.INT32,
4227 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004228 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004229 DType.FP16,
4230 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004231 ]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004232 wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
4233 outputDType = rng.choice(wrong_dtypes)
4234 else:
4235 outputDType = DType.INT32
4236
4237 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004238
4239 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004240 def conv2dOp(
4241 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
4242 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07004243
4244 # IFM: NHWC
4245 # Filter: OHWI
4246 # OFM: NHWC
4247
Kevin Cheng550ccc52021-03-03 11:21:43 -08004248 h = (
4249 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004250 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004251 + padding[0]
4252 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004253 - (filter.shape[1] - 1) * dilations[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004254 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004255
Kevin Cheng550ccc52021-03-03 11:21:43 -08004256 w = (
4257 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004258 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004259 + padding[2]
4260 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004261 - (filter.shape[2] - 1) * dilations[1]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004262 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004263
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004264 if error_name == ErrorIf.ConvOutputShapeMismatch:
4265 choices = [1, 2, 3]
4266 change = rng.choice(choices)
4267 # increment in multiples of stride to not hit non-integer error case
4268 if change in [1, 3]:
4269 h = h + (rng.choice(choices) * strides[0])
4270 if change in [2, 3]:
4271 w = w + (rng.choice(choices) * strides[1])
Les Bell0e027d42021-11-09 14:42:14 +00004272
Eric Kunzee5e26762020-10-13 16:11:07 -07004273 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
4274
James Ward8b390432022-08-12 20:48:56 +01004275 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004276 # Pick some potentially correct output dtype if input type is incorrect
4277 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004278 else:
James Ward8b390432022-08-12 20:48:56 +01004279 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004280
4281 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004282 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004283 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004284 else:
4285 excludes = [out_dtype]
4286 wrong_dtypes = list(usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004287 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07004288
Kevin Cheng550ccc52021-03-03 11:21:43 -08004289 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004290
4291 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004292 def conv3dOp(
4293 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
4294 ):
Kevin Cheng1533b852021-09-01 12:51:58 -07004295
4296 # IFM: NDHWC
4297 # Filter: ODHWI
4298 # OFM: NDHWC
4299
4300 d = (
4301 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004302 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07004303 + padding[0]
4304 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004305 - (filter.shape[1] - 1) * dilations[0]
Kevin Cheng1533b852021-09-01 12:51:58 -07004306 ) // strides[0] + 1
4307
4308 h = (
4309 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004310 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07004311 + padding[2]
4312 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004313 - (filter.shape[2] - 1) * dilations[1]
Kevin Cheng1533b852021-09-01 12:51:58 -07004314 ) // strides[1] + 1
4315
4316 w = (
4317 ifm.shape[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004318 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07004319 + padding[4]
4320 + padding[5]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004321 - (filter.shape[3] - 1) * dilations[2]
Kevin Cheng1533b852021-09-01 12:51:58 -07004322 ) // strides[2] + 1
4323
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004324 if error_name == ErrorIf.ConvOutputShapeMismatch:
4325 choices = [1, 2, 3, 4]
4326 change = rng.choice(choices)
4327 # increment in multiples of stride to not hit non-integer error case
4328 if change in [1, 4]:
4329 d = d + (rng.choice(choices) * strides[0])
4330 if change in [2, 4]:
4331 h = h + (rng.choice(choices) * strides[1])
4332 if change in [3, 4]:
4333 w = w + (rng.choice(choices) * strides[2])
Les Bell0e027d42021-11-09 14:42:14 +00004334
Kevin Cheng1533b852021-09-01 12:51:58 -07004335 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
4336
James Ward8b390432022-08-12 20:48:56 +01004337 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004338 # Pick some potentially correct output dtype if input type is incorrect
4339 out_dtype = DType.INT32
Kevin Cheng1533b852021-09-01 12:51:58 -07004340 else:
James Ward8b390432022-08-12 20:48:56 +01004341 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004342
4343 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004344 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004345 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004346 else:
4347 excludes = [out_dtype]
4348 wrong_dtypes = list(usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004349 out_dtype = rng.choice(wrong_dtypes)
Kevin Cheng1533b852021-09-01 12:51:58 -07004350
4351 return ser.addOutput(ofm_shape, out_dtype)
4352
4353 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004354 def depthwiseConv2dOp(
James Ward8b390432022-08-12 20:48:56 +01004355 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004356 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07004357 # IFM: NHWC
4358 # Filter: HWCM
4359 # OFM: NHW C*M
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004360
Kevin Cheng550ccc52021-03-03 11:21:43 -08004361 h = (
4362 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004363 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004364 + padding[0]
4365 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004366 - (filter.shape[0] - 1) * dilations[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004367 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004368
Kevin Cheng550ccc52021-03-03 11:21:43 -08004369 w = (
4370 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004371 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004372 + padding[2]
4373 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004374 - (filter.shape[1] - 1) * dilations[1]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004375 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004376
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004377 if error_name == ErrorIf.ConvOutputShapeMismatch:
4378 choices = [1, 2, 3]
4379 change = rng.choice(choices)
4380 # increment in multiples of stride to not hit non-integer error case
4381 if change in [1, 3]:
4382 h = h + (rng.choice(choices) * strides[0])
4383 if change in [2, 3]:
4384 w = w + (rng.choice(choices) * strides[1])
Les Bell0e027d42021-11-09 14:42:14 +00004385
Eric Kunzee5e26762020-10-13 16:11:07 -07004386 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
4387
James Ward8b390432022-08-12 20:48:56 +01004388 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004389 # Pick some potentially correct output dtype if input type is incorrect
4390 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004391 else:
James Ward8b390432022-08-12 20:48:56 +01004392 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004393
4394 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004395 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004396 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004397 else:
4398 excludes = [out_dtype]
4399 wrong_dtypes = list(usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004400 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07004401
Kevin Cheng550ccc52021-03-03 11:21:43 -08004402 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004403
4404 @staticmethod
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004405 def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004406 # input: NHWC
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004407 if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004408 # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004409 h = 1
4410 w = 1
4411 else:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004412 h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
4413 w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004414
4415 if error_name == ErrorIf.PoolingOutputShapeMismatch:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004416 choices = [1, 2, 3]
4417 change = rng.choice(choices)
4418 # increment in multiples of stride to not hit non-integer error case
4419 if change in [1, 3]:
4420 h = h + (rng.choice(choices) * stride[0])
4421 if change in [2, 3]:
4422 w = w + (rng.choice(choices) * stride[1])
Eric Kunzee5e26762020-10-13 16:11:07 -07004423 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004424
4425 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004426 all_dtypes = [
4427 DType.INT8,
4428 DType.INT16,
4429 DType.INT32,
4430 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004431 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004432 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004433 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004434 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004435 wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
4436 outputDType = rng.choice(wrong_dtypes)
4437 else:
4438 outputDType = ifm.dtype
4439
4440 return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004441
4442 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004443 def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004444 # input: N, IC
4445 # filter: OC, IC
4446 # output: N, OC
4447
4448 output_shape = [input.shape[0], filter.shape[0]]
4449
James Ward8b390432022-08-12 20:48:56 +01004450 # Validated in arg_gen (also invalidated for ErrorIf)
4451 out_dtype = accum_dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004452
Kevin Cheng550ccc52021-03-03 11:21:43 -08004453 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004454
4455 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004456 def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
Kevin Cheng2d60f002021-06-09 14:18:32 -07004457 # a: N, H, C
4458 # b: N, C, W
4459 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07004460
Kevin Cheng2d60f002021-06-09 14:18:32 -07004461 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004462
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004463 if error_name == ErrorIf.WrongOutputType:
4464 if a.dtype == DType.INT8:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004465 incorrect_types = (
4466 DType.INT4,
4467 DType.INT8,
4468 DType.INT16,
4469 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004470 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004471 DType.FP16,
4472 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004473 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004474 elif a.dtype == DType.INT16:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004475 incorrect_types = (
4476 DType.INT4,
4477 DType.INT8,
4478 DType.INT16,
4479 DType.INT32,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004480 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004481 DType.FP16,
4482 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004483 )
James Ward24dbc422022-10-19 12:20:31 +01004484 elif (
4485 a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
4486 ):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004487 incorrect_types = (
4488 DType.INT4,
4489 DType.INT8,
4490 DType.INT16,
4491 DType.INT32,
4492 DType.INT48,
4493 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004494 out_dtype = rng.choice(a=incorrect_types)
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004495 elif error_name == ErrorIf.WrongInputType:
4496 # Pick some potentially correct output dtype if input type is incorrect
4497 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004498 else:
James Ward8b390432022-08-12 20:48:56 +01004499 out_dtype = accum_dtype # Validated in arg_gen
Eric Kunzee5e26762020-10-13 16:11:07 -07004500
Kevin Cheng550ccc52021-03-03 11:21:43 -08004501 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004502
4503 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004504 def concatOp(ser, rng, axis, *a, error_name=None):
Matthew Haddon818ab902021-07-27 09:12:49 +01004505 input1 = a[0]
4506 remaining_inputs = a[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07004507
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004508 # calculate the output shape, if possible, otherwise just use the first input shape
Matthew Haddon818ab902021-07-27 09:12:49 +01004509 output_shape = input1.shape.copy()
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004510 if not (
4511 # unable to concat tensors of different ranks
4512 error_name == ErrorIf.ConcatInputRankMismatch
4513 # unable to concat tensors along an invalid axis
4514 or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004515 ):
4516 for tensor in remaining_inputs:
4517 output_shape[axis] += tensor.shape[axis]
Eric Kunzee5e26762020-10-13 16:11:07 -07004518
Matthew Haddon01c359d2021-10-15 16:30:48 +01004519 if error_name == ErrorIf.ConcatShapeSumMismatch:
4520 output_shape[axis] += rng.integers(5, 10)
4521
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004522 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004523 all_dtypes = {
4524 DType.INT8,
4525 DType.INT16,
4526 DType.INT32,
4527 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004528 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004529 DType.FP16,
4530 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004531 }
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004532 wrong_dtypes = list(all_dtypes - set([input1.dtype]))
4533 outputDType = rng.choice(wrong_dtypes)
4534 else:
4535 outputDType = input1.dtype
Matthew Haddon818ab902021-07-27 09:12:49 +01004536
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004537 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004538
4539 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01004540 def padOp(ser, rng, a, padding, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004541
4542 output_shape = a.shape.copy()
4543
4544 for i in range(len(output_shape)):
4545 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
4546
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004547 if error_name == ErrorIf.PadOutputShapeMismatch:
4548 bad_dim = rng.choice(range(len(output_shape)))
4549 output_shape[bad_dim] -= rng.choice([1, 2])
Luke Huttona4e48ca2023-02-22 11:53:48 +00004550 elif error_name == ErrorIf.RankMismatch:
4551 output_shape = get_rank_mismatch_shape(rng, output_shape)
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004552
Matthew Haddone807aae2021-10-11 18:12:58 +01004553 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004554 all_dtypes = [
4555 DType.INT8,
4556 DType.INT16,
4557 DType.INT32,
4558 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004559 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004560 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004561 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004562 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01004563 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4564 outputDType = rng.choice(wrong_dtypes)
4565 else:
4566 outputDType = a.dtype
4567
4568 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004569
4570 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01004571 def reshapeOp(ser, rng, a, shape, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004572 output_shape = shape.copy()
4573
Matthew Haddone807aae2021-10-11 18:12:58 +01004574 if error_name == ErrorIf.TensorSizeInputOutputMismatch:
4575 for i in range(len(output_shape)):
4576 output_shape[i] = output_shape[i] + rng.integers(1, 10)
4577
4578 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004579 all_dtypes = [
4580 DType.INT8,
4581 DType.INT16,
4582 DType.INT32,
4583 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004584 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004585 DType.FP16,
4586 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004587 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01004588 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4589 outputDType = rng.choice(wrong_dtypes)
4590 else:
4591 outputDType = a.dtype
4592
4593 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004594
4595 @staticmethod
Luke Huttona4e48ca2023-02-22 11:53:48 +00004596 def sliceOp(ser, rng, input, start, size, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004597
Matthew Haddone807aae2021-10-11 18:12:58 +01004598 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004599 all_dtypes = [
4600 DType.INT8,
4601 DType.INT16,
4602 DType.INT32,
4603 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004604 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004605 DType.FP16,
4606 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004607 ]
Luke Huttona4e48ca2023-02-22 11:53:48 +00004608 wrong_dtypes = list(set(all_dtypes) - set([input.dtype]))
Matthew Haddone807aae2021-10-11 18:12:58 +01004609 outputDType = rng.choice(wrong_dtypes)
4610 else:
Luke Huttona4e48ca2023-02-22 11:53:48 +00004611 outputDType = input.dtype
Matthew Haddone807aae2021-10-11 18:12:58 +01004612
Luke Huttona4e48ca2023-02-22 11:53:48 +00004613 output_shape = size.copy()
Matthew Haddone807aae2021-10-11 18:12:58 +01004614 if error_name == ErrorIf.SizeOutputShapeMismatch:
Matthew Haddone807aae2021-10-11 18:12:58 +01004615 for index in range(len(output_shape)):
4616 if output_shape[index] <= 2:
4617 output_shape[index] = output_shape[index] + rng.choice([1, 2])
4618 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004619 output_shape[index] = output_shape[index] + rng.choice(
4620 [-2, -1, 1, 2]
4621 )
Luke Huttona4e48ca2023-02-22 11:53:48 +00004622 elif error_name == ErrorIf.InputSizeStartLengthMismatch:
4623 output_shape = input.shape.copy()
4624 elif error_name == ErrorIf.RankMismatch:
4625 output_shape = get_rank_mismatch_shape(rng, output_shape)
Matthew Haddone807aae2021-10-11 18:12:58 +01004626
4627 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004628
4629 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004630 def tileOp(ser, rng, a, multiples, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004631
4632 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08004633 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07004634
4635 for i in range(len(output_shape)):
4636 output_shape[i] = a.shape[i] * multiples[i]
4637
Luke Huttona4e48ca2023-02-22 11:53:48 +00004638 if error_name == ErrorIf.RankMismatch:
4639 output_shape = get_rank_mismatch_shape(rng, output_shape)
4640
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004641 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004642 all_dtypes = [
4643 DType.INT8,
4644 DType.INT16,
4645 DType.INT32,
4646 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004647 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004648 DType.FP16,
4649 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004650 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004651 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4652 outputDType = rng.choice(wrong_dtypes)
4653 else:
4654 outputDType = a.dtype
4655
4656 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004657
4658 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01004659 def transposeOp(ser, rng, a, perms, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004660 output_shape = a.shape.copy()
Matthew Haddone807aae2021-10-11 18:12:58 +01004661
Kevin Cheng550ccc52021-03-03 11:21:43 -08004662 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07004663
Luke Huttona4e48ca2023-02-22 11:53:48 +00004664 if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]:
Matthew Haddone807aae2021-10-11 18:12:58 +01004665 for i in range(len(output_shape)):
4666 output_shape[i] = a.shape[perms[i]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004667
Luke Huttona4e48ca2023-02-22 11:53:48 +00004668 if error_name == ErrorIf.TensorSizeInputOutputMismatch:
4669 for i in range(len(output_shape)):
4670 output_shape[i] += rng.integers(1, 10)
4671 elif error_name == ErrorIf.RankMismatch:
4672 output_shape = get_rank_mismatch_shape(rng, output_shape)
4673
Matthew Haddone807aae2021-10-11 18:12:58 +01004674 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004675 all_dtypes = [
4676 DType.INT8,
4677 DType.INT16,
4678 DType.INT32,
4679 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004680 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004681 DType.FP16,
4682 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004683 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01004684 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4685 outputDType = rng.choice(wrong_dtypes)
4686 else:
4687 outputDType = a.dtype
4688
4689 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004690
4691 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004692 def gatherOp(ser, rng, values, indices, error_name=None):
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004693 if error_name != ErrorIf.WrongRank:
4694 assert len(values.shape) == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -08004695 assert len(indices.shape) == 2
4696 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07004697
Kevin Cheng77d0f762020-11-24 10:26:32 -08004698 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
4699
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004700 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004701 all_dtypes = [
4702 DType.INT8,
4703 DType.INT16,
4704 DType.INT32,
4705 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004706 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004707 DType.FP16,
4708 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004709 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004710 wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
4711 outputDType = rng.choice(wrong_dtypes)
4712 else:
4713 outputDType = values.dtype
4714
4715 return ser.addOutput(output_shape, outputDType)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004716
4717 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004718 def scatterOp(ser, rng, values_in, indices, input, error_name=None):
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004719 if error_name != ErrorIf.WrongRank:
4720 assert len(values_in.shape) == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -08004721 assert len(indices.shape) == 2
4722 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08004723 assert values_in.shape[0] == indices.shape[0] # N
4724 assert input.shape[1] == indices.shape[1] # W
4725 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08004726
4727 output_shape = values_in.shape
4728
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004729 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004730 all_dtypes = [
4731 DType.INT8,
4732 DType.INT16,
4733 DType.INT32,
4734 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004735 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004736 DType.FP16,
4737 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004738 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004739 wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
4740 outputDType = rng.choice(wrong_dtypes)
4741 else:
4742 outputDType = values_in.dtype
4743
4744 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004745
4746 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004747 def tableOp(ser, rng, input, error_name=None):
4748 # Same shape as the input, dtype dependent on input dtype
4749 if error_name != ErrorIf.WrongInputType:
4750 assert input.dtype == DType.INT16 or input.dtype == DType.INT8
Kevin Chengfe392ce2021-10-18 21:51:55 +00004751 output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004752 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004753 wrong_dtypes = [
4754 DType.INT8,
4755 DType.INT16,
4756 DType.INT32,
4757 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004758 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004759 DType.FP16,
4760 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004761 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004762 wrong_dtypes.remove(output_dtype)
4763 output_dtype = rng.choice(wrong_dtypes)
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004764 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004765
4766 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08004767 def resizeOp(
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004768 serializer,
4769 rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004770 input,
4771 mode,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004772 scale,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004773 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004774 border,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004775 input_dtype,
4776 output_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004777 error_name=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004778 ):
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004779 # Calculate OH, OW
4780 scale_y_n = scale[0]
4781 scale_y_d = scale[1]
4782 scale_x_n = scale[2]
4783 scale_x_d = scale[3]
4784 if error_name == ErrorIf.ScaleSmallerEqualZero:
4785 scale_y_n = max(scale_y_n, 1)
4786 scale_y_d = max(scale_y_d, 1)
4787 scale_x_n = max(scale_x_n, 1)
4788 scale_x_d = max(scale_x_d, 1)
4789
4790 oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
4791 ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1
4792
4793 if error_name is not None:
4794 # Make sure the output tensor is valid, which can occur when
4795 # scale, offset or border have been changed for ERROR_IFs
4796 oh = max(oh, 1)
4797 ow = max(ow, 1)
4798 if error_name != ErrorIf.MaxDimExceeded:
4799 oh = min(oh, MAX_RESIZE_DIMENSION - 1)
4800 ow = min(ow, MAX_RESIZE_DIMENSION - 1)
4801
4802 if error_name == ErrorIf.ResizeOutputShapeMismatch:
4803 choices = [1, 2, 3]
4804 change = rng.choice(choices)
4805 # increment in multiples of scale_y/x_d so we don't hit non-integer error case
4806 if change in [1, 3]:
4807 if oh + scale_y_d >= MAX_RESIZE_DIMENSION:
4808 oh -= scale_y_d
4809 assert oh > 0 # Should have been caught in agResize
4810 else:
4811 oh += scale_y_d
4812 if change in [2, 3]:
4813 if ow + scale_x_d >= MAX_RESIZE_DIMENSION:
4814 ow -= scale_x_d
4815 assert ow > 0 # Should have been caught in agResize
4816 else:
4817 ow += scale_x_d
4818
Matthew Haddon848efb42021-09-09 12:30:53 +01004819 if error_name == ErrorIf.WrongRank:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004820 output_dims = [
4821 input.shape[0],
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004822 oh,
4823 ow,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004824 input.shape[0],
4825 ]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004826 elif error_name == ErrorIf.BatchMismatch:
4827 output_dims = [
4828 input.shape[0] + rng.integers(1, 10),
4829 oh,
4830 ow,
4831 input.shape[3],
4832 ]
4833 elif error_name == ErrorIf.ChannelMismatch:
4834 output_dims = [
4835 input.shape[0],
4836 oh,
4837 ow,
4838 input.shape[3] + rng.integers(1, 10),
4839 ]
Matthew Haddon848efb42021-09-09 12:30:53 +01004840 else:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004841 output_dims = [input.shape[0], oh, ow, input.shape[3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004842
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004843 return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004844
4845 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004846 def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004847 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004848
4849 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004850 def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004851 if error_name == ErrorIf.ConvOutputShapeMismatch:
4852 choices = [1, 2, 3]
4853 change = rng.choice(choices)
4854 if change in [1, 3]:
4855 output_shape[1] = output_shape[1] + rng.choice(choices)
4856 if change in [2, 3]:
4857 output_shape[2] = output_shape[2] + rng.choice(choices)
4858
James Ward8b390432022-08-12 20:48:56 +01004859 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004860 # Pick some potentially correct output dtype if input type is incorrect
4861 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004862 else:
James Ward8b390432022-08-12 20:48:56 +01004863 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004864
4865 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004866 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004867 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004868 else:
4869 excludes = [out_dtype]
4870 wrong_dtypes = list(usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004871 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07004872
Kevin Cheng550ccc52021-03-03 11:21:43 -08004873 return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00004874
4875 @staticmethod
Luke Hutton57287132023-02-06 14:54:18 +00004876 def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
4877 outputs = []
4878
4879 assert ifm1.dtype == ifm2.dtype
4880 input_dtype = ifm1.dtype
4881
4882 if error_name != ErrorIf.FFTInputShapeMismatch:
4883 assert ifm1.shape == ifm2.shape
4884
4885 input_shape = ifm1.shape
4886 if error_name != ErrorIf.WrongRank:
4887 assert len(input_shape) == 3
4888
4889 output_shape = input_shape.copy()
4890 output_dtype = input_dtype
4891
4892 if error_name == ErrorIf.WrongOutputType:
4893 excludes = [DType.FP32]
4894 wrong_dtypes = list(usableDTypes(excludes=excludes))
4895 output_dtype = rng.choice(wrong_dtypes)
4896 elif error_name == ErrorIf.BatchMismatch:
4897 output_shape[0] += rng.integers(1, 10)
4898 elif error_name == ErrorIf.FFTOutputShapeMismatch:
4899 modify_dim = rng.choice([1, 2])
4900 output_shape[modify_dim] += rng.integers(1, 10)
4901
4902 outputs.append(serializer.addOutput(output_shape, output_dtype))
4903 outputs.append(serializer.addOutput(output_shape, output_dtype))
4904 return outputs
4905
4906 @staticmethod
Luke Hutton261b7b62023-01-10 14:50:31 +00004907 def rfft2dOp(serializer, rng, value, error_name=None):
4908 outputs = []
4909
4910 input_shape = value.shape
4911 if error_name != ErrorIf.WrongRank:
4912 assert len(input_shape) == 3
4913
4914 output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1]
4915
4916 output_dtype = value.dtype
4917 if error_name == ErrorIf.WrongOutputType:
4918 excludes = [DType.FP32]
4919 wrong_dtypes = list(usableDTypes(excludes=excludes))
4920 output_dtype = rng.choice(wrong_dtypes)
4921 elif error_name == ErrorIf.BatchMismatch:
Luke Hutton57287132023-02-06 14:54:18 +00004922 output_shape[0] += rng.integers(1, 10)
4923 elif error_name == ErrorIf.FFTOutputShapeMismatch:
4924 modify_dim = rng.choice([1, 2])
4925 output_shape[modify_dim] += rng.integers(1, 10)
Luke Hutton261b7b62023-01-10 14:50:31 +00004926
4927 outputs.append(serializer.addOutput(output_shape, output_dtype))
4928 outputs.append(serializer.addOutput(output_shape, output_dtype))
4929 return outputs