blob: f3ca51209e17d345f5026b837939856cd3f3bef3 [file] [log] [blame]
Eric Kunzea1d49852022-01-04 10:07:29 -08001# Copyright (c) 2020-2022, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +010016from generator.tosa_utils import DTYPE_ATTRIBUTES
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010017from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_utils import usableDTypes
James Ward24dbc422022-10-19 12:20:31 +010019from generator.tosa_utils import vect_f32_to_bf16
Les Bell0e027d42021-11-09 14:42:14 +000020from tosa.DType import DType
21from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010022
23
Eric Kunzee5e26762020-10-13 16:11:07 -070024class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010025 # Maximum rank of tensor supported by test generator.
26 TOSA_TENSOR_MAX_RANK = 6
27
def __init__(self, args):
    """Set up the test generator from the parsed command-line ``args``.

    Records the output directory and random seed, seeds the RNG, builds the
    operator lists and stores the floating-point sampling range. No
    serializer exists until createSerializer() is called.
    """
    # presumably an argparse namespace; reads output_dir, random_seed and
    # tensor_fp_value_range here - TODO confirm against the caller
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set (truthy), makeShape returns this shape instead of a random one
    self.targetted_shape = None
    # Work out floating point range; the bounds may be supplied in any order
    fp_range = args.tensor_fp_value_range
    self.random_fp_low, self.random_fp_high = min(fp_range), max(fp_range)
Eric Kunzee5e26762020-10-13 16:11:07 -070042
43 def createSerializer(self, opName, testPath):
44 self.testPath = os.path.join(opName, testPath)
45
46 fullPath = os.path.join(self.basePath, self.testPath)
47 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnsona0848c62022-09-15 15:01:30 +010048 self.ser = ts.TosaSerializer(fullPath, saveConstsToFile=self.args.dump_consts)
Eric Kunzee5e26762020-10-13 16:11:07 -070049
def getSerializer(self):
    """Return the current serializer (None before createSerializer() runs)."""
    current = self.ser
    return current
52
def serialize(self, testName):
    """Write the current test's files into its test directory.

    Writes ``<testName>.tosa`` (the serializer's binary output) and
    ``desc.json`` (the serializer's JSON descriptor, which is told the name
    of the .tosa file) into ``<basePath>/<testPath>``, the directory chosen by
    the most recent createSerializer() call.
    """
    test_dir = os.path.join(self.basePath, self.testPath)
    tosa_file = "{}.tosa".format(testName)

    with open(os.path.join(test_dir, tosa_file), "wb") as fd:
        fd.write(self.ser.serialize())

    # Encode the descriptor explicitly as UTF-8 so the file contents do not
    # depend on the platform's default locale encoding.
    with open(os.path.join(test_dir, "desc.json"), "w", encoding="utf-8") as fd:
        fd.write(self.ser.writeJson(tosa_file))
Eric Kunzee5e26762020-10-13 16:11:07 -070061
def resetRNG(self, seed=None):
    """Replace the RNG with a newly seeded one.

    When ``seed`` is None, ``random_seed + 1`` is used, so the new stream
    differs from the one seeded with ``random_seed`` in __init__.
    """
    new_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(new_seed)
66
def getRandTensor(self, shape, dtype):
    """Return a numpy array of random values with ``shape`` suited to ``dtype``.

    BOOL gives np.bool_. Integer types are drawn with rng.integers over their
    value range (INT4 uses the TOSA weight range -7..7) and returned as
    np.int32, except INT48 which is returned as np.int64. FP16 gives
    np.float16; FP32 and BF16 give np.float32. Floats are drawn with
    rng.uniform between random_fp_low and random_fp_high.

    Raises:
        Exception: if ``dtype`` is not supported.
    """
    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))

    # (dtype, low, exclusive high, numpy container type)
    integer_specs = (
        # TOSA specific INT4 weight range from -7 to 7
        (DType.INT4, -7, 8, np.int32),
        (DType.INT8, -128, 128, np.int32),
        (DType.UINT8, 0, 256, np.int32),
        (DType.INT16, -32768, 32768, np.int32),
        (DType.UINT16, 0, 65536, np.int32),
        (DType.INT32, -(1 << 31), 1 << 31, np.int32),
        (DType.INT48, -(1 << 47), 1 << 47, np.int64),
    )
    for candidate, low, high, container in integer_specs:
        if dtype == candidate:
            return container(self.rng.integers(low=low, high=high, size=shape))

    def sample_fp():
        return self.rng.uniform(
            low=self.random_fp_low, high=self.random_fp_high, size=shape
        )

    if dtype == DType.FP16:
        return np.float16(sample_fp())
    if dtype == DType.BF16:
        # Floor the last 16 bits of each f32 value
        return np.float32(vect_f32_to_bf16(np.float32(sample_fp())))
    if dtype == DType.FP32:
        return np.float32(sample_fp())
    raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700111
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder per (shape, dtype) pair, filled with random data.

    Returns the values returned by ``self.ser.addPlaceholder``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    tensors = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = self.getRandTensor(shape, dtype)
        tensors.append(self.ser.addPlaceholder(shape, dtype, data))
    return tensors
122
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant per (shape, dtype) pair, filled with random data.

    Returns the values returned by ``self.ser.addConst``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    tensors = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = self.getRandTensor(shape, dtype)
        tensors.append(self.ser.addConst(shape, dtype, data))
    return tensors
133
def makeShape(self, rank):
    """Return an np.int32 shape array.

    If a truthy target shape has been set with setTargetShape(), that shape
    is returned and ``rank`` is ignored. Otherwise ``rank`` dimensions are
    drawn from ``[tensor_shape_range[0], tensor_shape_range[1])``.
    """
    if not self.targetted_shape:
        shape_range = self.args.tensor_shape_range
        low, high = shape_range[0], shape_range[1]
        return np.int32(self.rng.integers(low=low, high=high, size=rank))
    return np.int32(self.targetted_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700144
def setTargetShape(self, shape):
    """Make makeShape() return ``shape``; a falsy value restores random shapes."""
    self.targetted_shape = shape
147
def randInt(self, low=0, high=256):
    """Return one random np.int32 drawn from ``[low, high)``."""
    draws = self.rng.integers(low=low, high=high, size=1)
    return np.int32(draws)[0]
150
def getRandNumberDType(self, dtype):
    """Return a single random scalar suitable for ``dtype``.

    Floating-point types are drawn with rng.uniform between random_fp_low and
    random_fp_high (BF16 is passed through vect_f32_to_bf16). BOOL returns
    the result of rng.choice([False, True]). Integer types use the same
    ranges as getRandTensor, including UINT8 and UINT16. They are returned
    as np.int32, except INT48 which is returned as np.int64.

    Raises:
        Exception: if ``dtype`` is not supported.
    """
    if dtype == DType.FP32:
        return np.float32(
            self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
        )
    elif dtype == DType.FP16:
        return np.float16(
            self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
        )
    elif dtype == DType.BF16:
        rand_f32 = np.float32(
            self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
        )
        return vect_f32_to_bf16(rand_f32)
    elif dtype == DType.BOOL:
        return self.rng.choice([False, True])
    # TOSA specific INT4 weight range from -7 to 7
    elif dtype == DType.INT4:
        low, high = (-7, 8)
    elif dtype == DType.INT8:
        low, high = (-128, 128)
    elif dtype == DType.UINT8:
        # Previously rejected as unknown although getRandTensor supports it
        low, high = (0, 256)
    elif dtype == DType.INT16:
        low, high = (-32768, 32768)
    elif dtype == DType.UINT16:
        low, high = (0, 65536)
    elif dtype == DType.INT32:
        low, high = (-(1 << 31), (1 << 31))
    elif dtype == DType.INT48:
        low, high = (-(1 << 47), (1 << 47))
        # Special size: INT48 values do not fit in int32
        return np.int64(self.rng.integers(low, high, size=1))[0]
    else:
        raise Exception("Unknown dtype: {}".format(dtype))

    return np.int32(self.rng.integers(low, high, size=1))[0]
184
def shapeStr(self, shape):
    """Format ``shape`` for test names, e.g. ``[1, 2, 3]`` -> ``"1x2x3"``."""
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700193
def typeStr(self, dtype):
    """Return the short string used in test names for ``dtype``.

    A list or tuple (at least two entries) is rendered as the names of its
    first two entries joined with "x". Every entry, including any further
    one, is still converted, so unknown entries still raise.

    Raises:
        Exception: if a dtype has no entry in DTYPE_ATTRIBUTES.
    """
    if not isinstance(dtype, (list, tuple)):
        if dtype not in DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return DTYPE_ATTRIBUTES[dtype]["str"]

    assert len(dtype) >= 2
    names = [self.typeStr(t) for t in dtype]
    # Limit types to the first 2 as the 3rd is the accumulator
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700207
def typeWidth(self, dtype):
    """Return the ``"width"`` entry of DTYPE_ATTRIBUTES for ``dtype``.

    Raises:
        Exception: if ``dtype`` has no entry in DTYPE_ATTRIBUTES.
    """
    if dtype not in DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700214
# Operator build functions
# Argument generators return a list of tuples
# (stringDescriptor, [build_fcn_arg_list]), where the string descriptor is
# used to generate the test name and build_fcn_arg_list is expanded and
# passed to one of the build_* functions below.
220
def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
    """Add a unary operator with input tensor ``a`` to the serializer.

    Args:
        op: either a bare Op value (int), or an op-list entry dict providing
            "op" and "operands".
        a: the input tensor.
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: the ERROR_IF case being generated, or None.
        qinfo: two zero points (input, output) used by NEGATE; treated as
            [0, 0] when None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # A bare Op value (int) carries no op-list metadata, so it is added
    # directly; IDENTITY is also added without ERROR_IF handling.
    if isinstance(op, int):
        self.ser.addOperator(op, a.name, result_tens.name, None)
        return result_tens
    elif op["op"] == Op.IDENTITY:
        self.ser.addOperator(op["op"], a.name, result_tens.name, None)
        return result_tens

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongOutputType:
        if result_tens.dtype not in [DType.INT8, DType.UINT8]:
            qinfo = [
                TosaQuantGen.getZeroPoint(self, a.dtype),
                TosaQuantGen.getZeroPoint(self, result_tens.dtype),
            ]

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = None
    if op["op"] == Op.NEGATE:
        # qinfo defaults to None; fall back to zero points of 0 (as
        # build_pool2d does) instead of failing on qinfo[0].
        if qinfo is None:
            qinfo = [0, 0]
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
271
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
    """Add a broadcasting binary operator ``op(a, b)`` to the serializer.

    Returns the output tensor, or None when the ERROR_IF validators reject
    this test configuration.
    """
    out = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be altered
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out.name]
    )

    error_if_args = dict(
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensor=out,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out
304
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a non-broadcasting binary operator ``op(a, b)`` to the serializer.

    ``validator_fcns`` and ``error_name`` are accepted but not used here; no
    ERROR_IF validation is performed. Returns the output tensor.
    """
    out_tensor = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op["op"], operand_names, [out_tensor.name])
    return out_tensor
309
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add the arithmetic right shift operator ``op(a, b)`` to the serializer.

    ``round`` is passed to ArithmeticRightShiftAttribute. Returns the output
    tensor, or None when ERROR_IF validation rejects the test.
    """
    out = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be altered
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out.name]
    )

    error_if_args = dict(
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensor=out,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    shift_attr = ts.TosaSerializerAttribute()
    shift_attr.ArithmeticRightShiftAttribute(round)
    self.ser.addOperator(op["op"], inputs, outputs, shift_attr)
    return out
347
def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
    """Add a multiply of ``a`` and ``b`` with the given ``shift`` to the serializer.

    For non-floating-point inputs the output dtype is forced to INT32. For
    the WrongOutputType ERROR_IF case, the output dtype is then replaced by
    one drawn from INT8/INT16/INT48. Returns the output tensor, or None when
    ERROR_IF validation rejects the test.
    """
    out = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Special for multiply: integer inputs always produce an INT32 result
    is_float_input = a.dtype in (DType.FP16, DType.BF16, DType.FP32)
    if not is_float_input:
        out.setDtype(DType.INT32)
    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = [DType.INT8, DType.INT16, DType.INT48]
        out.setDtype(self.rng.choice(wrong_dtypes))

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be altered
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out.name]
    )

    error_if_args = dict(
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensor=out,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    mul_attr = ts.TosaSerializerAttribute()
    mul_attr.MulAttribute(shift)
    self.ser.addOperator(op["op"], inputs, outputs, mul_attr)
    return out
392
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a table lookup of ``a`` through ``table`` to the serializer.

    Returns the output tensor, or None when ERROR_IF validation rejects the
    test.
    """
    out = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    table_attr = ts.TosaSerializerAttribute()
    table_attr.TableAttribute(table)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be altered
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    error_if_args = dict(
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensor=out,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs, table_attr)
    return out
426
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Add a select of ``a``/``b`` driven by ``cond`` to the serializer.

    Returns the output tensor, or None when ERROR_IF validation rejects the
    test.
    """
    out = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be altered
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, a.name, b.name], [out.name]
    )

    error_if_args = dict(
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensor=out,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out
463
def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a comparison operator ``op(a, b)`` to the serializer.

    Returns the output tensor, or None when ERROR_IF validation rejects the
    test.
    """
    out = OutputShaper.binaryComparisonOp(self.ser, self.rng, a, b, error_name)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be altered
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out.name]
    )

    error_if_args = dict(
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensor=out,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out
502
def build_argmax(self, op, a, axis, validator_fcns, error_name):
    """Add an argmax of ``a`` along ``axis`` to the serializer.

    Returns the output tensor, or None when ERROR_IF validation rejects the
    test.
    """
    out = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be altered
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    error_if_args = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensor=out,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], inputs, outputs, axis_attr)
    return out
537
def build_pool2d(
    self,
    op,
    input,
    accum_dtype,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D pooling operator on ``input`` to the serializer.

    ``kernel``, ``stride``, ``pad``, the two zero points in ``qinfo`` (or
    [0, 0] when None) and ``accum_dtype`` are passed to PoolAttribute.
    Returns the output tensor, or None when ERROR_IF validation rejects the
    test.
    """
    out = OutputShaper.pool2dOp(
        self.ser, self.rng, input, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and input.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, input.dtype),
            TosaQuantGen.getZeroPoint(self, out.dtype),
        ]

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be altered
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out.name]
    )

    error_if_args = dict(
        op=op,
        input_shape=input.shape,
        input_dtype=input.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensor=out,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    zero_points = [0, 0] if qinfo is None else qinfo
    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )
    self.ser.addOperator(op["op"], inputs, outputs, pool_attr)
    return out
599
def build_maxpool2d(
    self,
    op,
    input,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a max-pool operator by delegating to build_pool2d.

    Max pooling has no accumulator, so DType.UNKNOWN is supplied as the
    ``accum_dtype`` argument. Returns whatever build_pool2d returns.
    """
    no_accumulator = DType.UNKNOWN
    return self.build_pool2d(
        op,
        input,
        no_accumulator,
        stride,
        pad,
        kernel,
        validator_fcns=validator_fcns,
        error_name=error_name,
        qinfo=qinfo,
    )
624
def build_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D convolution of ``ifm`` with ``filter`` and ``bias`` to the serializer.

    Args:
        padding: four padding values (asserted).
        strides, dilations: passed to the output shaper, the ERROR_IF
            validators and ConvAttribute.
        accum_dtype: accumulator type passed to the shaper and ConvAttribute.
        qinfo: two zero points passed to ConvAttribute; treated as [0, 0]
            when None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    assert len(padding) == 4
    result_tens = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None; use zero points of 0 (as build_pool2d does)
    # instead of failing on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
696
def build_conv3d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 3D convolution of ``ifm`` with ``filter`` and ``bias`` to the serializer.

    Args:
        padding: six padding values (asserted).
        strides, dilations: passed to the output shaper, the ERROR_IF
            validators and ConvAttribute.
        accum_dtype: accumulator type passed to the shaper and ConvAttribute.
        qinfo: two zero points passed to ConvAttribute; treated as [0, 0]
            when None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None; use zero points of 0 (as build_pool2d does)
    # instead of failing on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
768
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a TRANSPOSE_CONV2D operator to the serializer.

    ``out_pad`` must have exactly four entries. It is passed to the ERROR_IF
    validators as ``pad`` and stored in the TransposeConvAttribute, together
    with ``stride``, ``output_shape``, the two ``qinfo`` zero points and
    ``accum_dtype``.

    Returns:
        The output tensor, or None when
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy result.
    """
    assert len(out_pad) == 4
    out_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    if error_name == ErrorIf.WrongInputType:
        if ifm.dtype not in (DType.INT8, DType.UINT8):
            # Zero points must be derived from the actual input/output dtypes.
            qinfo = [
                TosaQuantGen.getZeroPoint(self, ifm.dtype),
                TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
            ]

    names_in = [ifm.name, filter.name, bias.name]
    names_out = [out_tensor.name]
    operand_count = sum(op["operands"])
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, names_in, names_out
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=names_in,
        num_operands=operand_count,
        output_list=names_out,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    ):
        return None

    tconv_attr = ts.TosaSerializerAttribute()
    tconv_attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1], accum_dtype
    )
    self.ser.addOperator(op["op"], names_in, names_out, tconv_attr)
    return out_tensor
833
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a DEPTHWISE_CONV2D operator to the serializer.

    ``padding``, ``strides``, ``dilations``, the two ``qinfo`` zero points
    and ``accum_dtype`` are stored in a ConvAttribute. The same values are
    also handed to the ERROR_IF validators.

    Returns:
        The output tensor, or None when
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy result.
    """
    out_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # WrongInputType on a non-8-bit input: refresh the zero points so they
    # correspond to the input and output dtypes actually in use.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        zp_in = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        zp_out = TosaQuantGen.getZeroPoint(self, out_tensor.dtype)
        qinfo = [zp_in, zp_out]

    operand_count = sum(op["operands"])
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [ifm.name, filter.name, bias.name],
        [out_tensor.name],
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=in_names,
        num_operands=operand_count,
        output_list=out_names,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not passed:
        return None

    dw_attr = ts.TosaSerializerAttribute()
    zp_first, zp_second = qinfo[0], qinfo[1]
    dw_attr.ConvAttribute(
        padding, strides, dilations, zp_first, zp_second, accum_dtype
    )
    self.ser.addOperator(op["op"], in_names, out_names, dw_attr)
    return out_tensor
904
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a FULLY_CONNECTED operator to the serializer.

    ``op["operands"]`` is unpacked into exactly two counts, and their sum is
    given to the ERROR_IF validators. The two ``qinfo`` zero points and
    ``accum_dtype`` go into the FullyConnectedAttribute.

    Returns:
        The output tensor, or None when
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy result.
    """
    out_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    p_count, c_count = op["operands"]
    operand_count = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_count,
        accum_dtype=accum_dtype,
    ):
        return None

    fc_attr = ts.TosaSerializerAttribute()
    fc_attr.FullyConnectedAttribute(qinfo[0], qinfo[1], accum_dtype)
    self.ser.addOperator(op["op"], inputs, outputs, fc_attr)
    return out_tensor
953
def build_matmul(
    self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a MATMUL operator with operands ``a`` and ``b``.

    Both operands' shapes and dtypes are forwarded to the ERROR_IF
    validators. The two ``qinfo`` zero points and ``accum_dtype`` are stored
    in the MatMulAttribute.

    Returns:
        The output tensor, or None when
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy result.
    """
    out_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    p_count, c_count = op["operands"]
    operand_count = p_count + c_count
    lhs_rhs_names = [a.name, b.name]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, lhs_rhs_names, [out_tensor.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_count,
        accum_dtype=accum_dtype,
    )
    if not ok:
        return None

    mm_attr = ts.TosaSerializerAttribute()
    mm_attr.MatMulAttribute(qinfo[0], qinfo[1], accum_dtype)
    self.ser.addOperator(op["op"], inputs, outputs, mm_attr)
    return out_tensor
995
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Add a reduction operator (``op["op"]``) over ``axis``.

    ``validator_fcns`` now defaults to None, so the signature matches the
    other ``build_*`` methods. Callers that pass it positionally or by
    keyword are unaffected.

    Args:
        op: operator description; ``op["op"]`` is handed to the serializer
            and ``op["operands"]`` is unpacked into two counts that are summed.
        a: input tensor.
        axis: axis stored in the AxisAttribute and checked by the validators.
        validator_fcns: ERROR_IF validator functions to run (optional).
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The output tensor, or None when
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy result.
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1030
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Add a CLAMP operator with randomly drawn bounds.

    Two random values of ``a.dtype`` are drawn. Normally the smaller one
    becomes the minimum and the larger one the maximum. For
    ``ErrorIf.MaxSmallerMin`` the pair is redrawn until the two values
    differ, and the bounds are then swapped so that max < min.

    Returns:
        The output tensor, or None when
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy result.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def draw_pair():
        # Draw order (first, then second) is kept for RNG reproducibility.
        first = self.getRandNumberDType(a.dtype)
        second = self.getRandNumberDType(a.dtype)
        return [first, second]

    pair = draw_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # Equal values could never give max < min, so redraw until distinct.
        while pair[0] == pair[1]:
            pair = draw_pair()
        min_val, max_val = max(pair), min(pair)
    else:
        min_val, max_val = min(pair), max(pair)

    p_count, c_count = op["operands"]
    operand_count = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_count,
    ):
        return None

    # Floating-point inputs put the bounds in the last two ClampAttribute
    # arguments; integer inputs put them in the first two. The unused pair is 0.
    is_float = a.dtype in (DType.FP16, DType.BF16, DType.FP32)
    bounds = (0, 0, min_val, max_val) if is_float else (min_val, max_val, 0, 0)
    clamp_attr = ts.TosaSerializerAttribute()
    clamp_attr.ClampAttribute(*bounds)

    self.ser.addOperator(op["op"], inputs, outputs, clamp_attr)
    return out_tensor
1081
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Add a LEAKY_RELU operator whose attribute value is a random FP32 number.

    No ERROR_IF validation happens here. ``validator_fcns`` is unused, and
    ``error_name`` is only forwarded to ``OutputShaper.unaryOp``.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    leaky_attr = ts.TosaSerializerAttribute()

    alpha = self.getRandNumberDType(DType.FP32)
    leaky_attr.LeakyReluAttribute(alpha)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], leaky_attr)
    return out_tensor
1090
# TODO: PRELU needs an additional type/input; only ``a`` is wired up here.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Add a PRELU operator that takes ``a`` as its only input.

    No attribute is attached and no ERROR_IF validation happens here.
    ``validator_fcns`` is unused.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    operator = op["op"]
    self.ser.addOperator(operator, [a.name], [out_tensor.name])
    return out_tensor
1097
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Add a SIGMOID operator (no attribute) applied to ``a``.

    Returns:
        The output tensor, or None when
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy result.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    error_if_args = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1128
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Add a TANH operator (no attribute) applied to ``a``.

    Returns:
        The output tensor, or None when
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy result.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    in_names = [a.name]
    out_names = [out_tensor.name]
    first_count, second_count = op["operands"]
    total_operands = first_count + second_count
    # The ERROR_IF helper may replace these lists depending on error_name.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=total_operands,
    )
    if valid:
        self.ser.addOperator(op["op"], in_names, out_names)
        return out_tensor
    return None
1159
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Add a CONCAT operator over a variable number of input tensors.

    The final positional argument is the concatenation axis; all earlier
    ones are the input tensors. The axis is asserted to be exactly an
    ``int`` unless ``error_name`` is ``ErrorIf.WrongInputType``.

    Returns:
        The output tensor, or None when
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy result.
    """
    if error_name != ErrorIf.WrongInputType:
        assert type(a[-1]) is int

    # The axis rides at the end of the variadic argument list.
    axis = a[-1]
    tensors = a[:-1]

    out_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, *tensors, error_name=error_name
    )

    tensor_names = [t.name for t in tensors]
    p_count, c_count = op["operands"]
    operand_count = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, tensor_names, [out_tensor.name]
    )

    lead = tensors[0]
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=lead.shape,
        output_shape=out_tensor.shape,
        input_dtype=lead.dtype,
        output_dtype=out_tensor.dtype,
        inputs=tensors,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_count,
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], inputs, outputs, axis_attr)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001208
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a PAD operator to the serializer.

    The PadAttribute receives ``padding.flatten()`` plus the integer and
    float pad constants. The unflattened ``padding`` and ``qinfo`` are given
    to the ERROR_IF validators.

    Returns:
        The output tensor, or None when
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy result.
    """
    out_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    flat_padding = padding.flatten()
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(flat_padding, pad_const_int, pad_const_float)

    p_count, c_count = op["operands"]
    operand_count = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_count,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], inputs, outputs, pad_attr)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001254
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Add a RESHAPE operator that reshapes ``a`` to ``newShape``.

    Returns:
        The output tensor, or None when
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy result.
    """
    out_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, a, newShape, error_name
    )

    p_count, c_count = op["operands"]
    operand_count = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_count,
    ):
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)
    self.ser.addOperator(op["op"], inputs, outputs, reshape_attr)
    return out_tensor
1290
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Add a REVERSE operator on ``a`` along ``axis`` (stored as AxisAttribute).

    Returns:
        The output tensor, or None when
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy result.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    operand_count = sum(
        (count for count in _pair(op["operands"])),
    ) if False else None  # placeholder never used
    del operand_count

    p_count, c_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], inputs, outputs, axis_attr)
    return out_tensor
1325
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Add a TRANSPOSE operator that permutes ``a`` by ``perms``.

    ``perms`` is stored in a TransposeAttribute and is also checked by the
    ERROR_IF validators.

    Returns:
        The output tensor, or None when
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy result.
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    p_count, c_count = op["operands"]
    operand_count = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_count,
    )
    if not ok:
        return None

    self.ser.addOperator(op["op"], inputs, outputs, perm_attr)
    return out_tensor
1360
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Add a SLICE operator taking ``size`` elements of ``a`` from ``start``.

    ``start`` and ``size`` are stored in a SliceAttribute and are also given
    to the ERROR_IF validators.

    Returns:
        The output tensor, or None when
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy result.
    """
    out_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, a, start, size, error_name
    )

    p_count, c_count = op["operands"]
    operand_count = p_count + c_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_count,
    ):
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)
    return out_tensor
1398
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Add a TILE operator replicating ``a`` by ``multiples`` (TileAttribute).

    Returns:
        The output tensor, or None when
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy result.
    """
    out_tensor = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    p_count, c_count = op["operands"]
    operand_count = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    error_if_args = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)
    self.ser.addOperator(op["op"], inputs, outputs, tile_attr)
    return out_tensor
1432
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Build a GATHER operator reading from ``values`` with generated indices.

    A constant INT32 index tensor of shape [values.shape[0], W] is created.
    W is drawn via ``self.randInt`` from ``args.tensor_shape_range``. Every
    index is drawn from [0, values.shape[1]), so it stays within dimension 1
    of ``values``.

    Returns the output tensor, or None if ERROR_IF validation fails.
    """
    # Indices must not exceed dimension 1 (K) of the values tensor.
    k_dim = values.shape[1]
    w_dim = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values.shape[0], w_dim])
    )  # (N, W)
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tens = OutputShaper.gatherOp(
        self.ser, self.rng, values, indices, error_name
    )

    # Operand count = placeholder operands + const operands for this op.
    placeholders, consts = op["operands"]
    operand_total = placeholders + consts

    # ERROR_IF tests may deliberately corrupt the input/output name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out_tens.shape,
        input_dtype=values.dtype,
        output_dtype=out_tens.dtype,
        result_tensor=out_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tens
1479
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Build a SCATTER operator writing ``input`` into ``values_in``.

    A constant INT32 index tensor of shape
    [values_in.shape[0], input.shape[1]] is generated. Every entry is drawn
    from [0, values_in.shape[1]), which keeps the indices inside dimension 1
    of ``values_in``.

    Returns the output tensor, or None if ERROR_IF validation fails.
    """
    n_dim = values_in.shape[0]
    k_dim = values_in.shape[1]
    w_dim = input.shape[1]
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[n_dim, w_dim])
    )  # (N, W)
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tens = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    # Operand count = placeholder operands + const operands for this op.
    placeholders, consts = op["operands"]
    operand_total = placeholders + consts

    # ERROR_IF tests may deliberately corrupt the input/output name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input.name],
        [out_tens.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=out_tens.shape,
        input_dtype=values_in.dtype,
        output_dtype=out_tens.dtype,
        result_tensor=out_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tens
1524
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE operator on ``input``.

    ``mode``, ``scale``, ``offset`` and ``border`` are used both to shape the
    output (via ``OutputShaper.resizeOp``) and to fill the serializer's
    ResizeAttribute.

    Returns the output tensor, or None if ERROR_IF validation fails.
    """
    out_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    # Operand count = placeholder operands + const operands for this op.
    placeholders, consts = op["operands"]
    operand_total = placeholders + consts

    # ERROR_IF tests may deliberately corrupt the input/output name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=out_tens.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensor=out_tens,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return out_tens
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator passing ``val`` and ``val2`` through.

    Two output tensors are created with ``OutputShaper.unaryOp`` and both are
    wired to the operator. Only the first output tensor is returned.
    ``validator_fcns`` is accepted for signature parity with the other
    builders but is not used here.

    ``op`` is normally the operator-definition dict used by every other
    builder; the serializer needs the Op code stored under its "op" key.
    A bare Op code is also accepted for backward compatibility.
    """
    op_code = op["op"] if isinstance(op, dict) else op

    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    self.ser.addOperator(
        op_code, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1594
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as a graph output tensor and return it unchanged.

    No operator is added. ``op``, ``validator_fcns`` and ``error_name`` are
    accepted only so the signature matches the other ``build_*`` methods.
    """
    output_tens = val
    self.ser.addOutputTensor(output_tens)
    return output_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001598
# Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Build a CAST operator converting ``val`` to ``out_dtype``.

    The output tensor is created by ``OutputShaper.typeConversionOp``.
    Returns it, or None if ERROR_IF validation fails.
    """
    out_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # Operand count = placeholder operands + const operands for this op.
    placeholders, consts = op["operands"]
    operand_total = placeholders + consts

    # ERROR_IF tests may deliberately corrupt the input/output name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=out_tens.shape,
        input_dtype=val.dtype,
        output_dtype=out_tens.dtype,
        result_tensor=out_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tens
1632
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator converting ``val`` to ``out_dtype``.

    Zero points are picked from the dtypes:
      * random for INT8 / UINT8;
      * a random non-zero value for the zero-point ERROR_IF tests;
      * 0 or 32768 for UINT16;
      * otherwise 0.

    A random scale (one per channel of the last dimension when
    ``per_channel`` is set, else one) is clipped and converted to
    multiplier/shift pairs with ``TosaQuantGen.computeMultiplierAndShift``.

    For non-error scale32 tests, the input placeholder's data file is
    rewritten when needed so that ``value - input_zp`` lies in
    [-(1 << (shift - 1)), (1 << (shift - 1)) - 1].

    Returns the output tensor, or None if ERROR_IF validation fails.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # One scale per channel (last dimension) when per_channel is requested.
    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # NOTE: the branches below draw from self.rng; keep their order stable so
    # the generated tests remain reproducible for a given seed.
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            # Force a non-zero zero point for this ERROR_IF test
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            # Force a non-zero zero point for this ERROR_IF test
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width), with a drawn from [0, 1)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Compute the bounds with a Python int: shift_arr is int32, and under
        # NumPy >= 2 promotion rules (NEP 50) `-1 << shift_arr[i]` would stay
        # int32 and overflow for large shifts.
        half_range = 1 << (int(shift_arr[i]) - 1)
        min_shift_value_arr[i] = -half_range
        max_shift_value_arr[i] = half_range - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1)))
        assert val.placeholderFilename
        data_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(data_path)
        # Clamp (value - input_zp) per channel, then restore the zero point.
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(data_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1787
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks each output a constant tensor.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored.
    Fresh random INT32 constants of that shape are created inside each block.
    ``cond`` becomes a BOOL constant condition input.

    For the CondIfOutputList{Then,Else}GraphMismatch ERROR_IF tests, the
    matching block instead outputs a constant with a perturbed shape.

    Returns the INT32 output tensor, or None if ERROR_IF validation fails.
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    out_shape = then_tens.shape

    # Build a wrongly-shaped constant for the output-mismatch ERROR_IF tests.
    mismatch_errors = (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    )
    if error_name in mismatch_errors:
        bad_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(bad_shape):
            # Large dims may shrink or grow; small dims only grow, so no
            # dimension can drop below 1 (for non-negative input dims).
            if dim > 3:
                bad_shape[idx] += self.rng.choice([-3, -2, 2, 3])
            else:
                bad_shape[idx] += self.rng.choice([1, 2, 4])
        bad_arr = np.int32(self.rng.integers(0, 256, size=bad_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The operator's result tensor, shaped like the (correct) block outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block body is a single constant output (possibly wrongly shaped).
    block_specs = (
        (then_block, ErrorIf.CondIfOutputListThenGraphMismatch, then_arr),
        (else_block, ErrorIf.CondIfOutputListElseGraphMismatch, else_arr),
    )
    for block_name, block_error, block_arr in block_specs:
        self.ser.startBasicBlock(block_name)
        if error_name == block_error:
            block_out = self.ser.addConst(bad_shape, DType.INT32, bad_arr)
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_out)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks,
    )
    if not is_valid:
        return None

    return result_tens
1855
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose blocks combine ``a`` and ``b`` with a binary op.

    The block operators depend on the dtype of ``a``:
      * FP32/BF16/FP16/INT32: THEN block uses ADD, ELSE block uses SUB;
      * INT8/INT16: THEN block uses LOGICAL_RIGHT_SHIFT, ELSE block uses
        LOGICAL_LEFT_SHIFT.

    ``cond`` becomes a BOOL constant condition input. For the CondIf
    input/output list mismatch ERROR_IF tests, the selected block declares
    an input or output with a perturbed shape.

    Returns the output tensor, or None if ERROR_IF validation fails.

    Raises:
        AssertionError: if ``a.dtype`` has no block operators defined.
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Explicit raise (not `assert False`) so this still fails cleanly
        # under `python -O` instead of hitting a NameError below.
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # Use a distinct loop name: reusing `op` here would clobber the operator
    # definition dict that is passed to the ERROR_IF validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.startBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.basicBlocks,
    ):
        return None

    return result_tens
1937
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP that adds ``a`` into an accumulator ``iter_val`` times.

    Loop state is (iter, a, acc):
      * COND block: emits ``iter > 0``;
      * BODY block: computes ``acc + a`` and ``iter - 1``.

    The accumulator placeholder starts at zero. Several ERROR_IF tests swap
    in tensors with perturbed shapes, or a non-BOOL condition output, at the
    points they target.

    Returns the accumulator output tensor, or None if ERROR_IF validation
    fails.
    """
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    def perturb_dims(tens):
        """Shift every dim of ``tens.shape`` in place by a random non-zero step."""
        for idx in range(len(tens.shape)):
            tens.shape[idx] += self.rng.choice([-3, -2, 2, 3])
        return tens

    # Accumulator tensor, initialised to zero
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        mismatched_acc = perturb_dims(deepcopy(acc))
        acc_out = self.ser.addIntermediate(mismatched_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    graph_mismatch_errors = [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]
    if error_name in graph_mismatch_errors:
        bad_iter = perturb_dims(deepcopy(iter_tens))
        if not bad_iter.shape:
            # A rank-0 iterator has no dims to perturb, so add one instead.
            bad_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
        bad_acc = perturb_dims(deepcopy(acc))

    # COND block (input: iter, a, acc; output: cond_tens)
    self.ser.startBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (bad_iter, a, bad_acc)
    else:
        cond_inputs = (iter_tens, a, acc)
    for tens in cond_inputs:
        self.ser.addInputTensor(tens)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_dtype = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_dtype = DType.BOOL
    cond_tens = self.ser.addOutput([], cond_dtype)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: iter, a, acc; output: iter, a, acc)
    # Local intermediate tensors must be declared here for the outputs.
    self.ser.startBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (bad_iter, a, bad_acc)
    else:
        body_inputs = (iter_tens, a, acc)
    for tens in body_inputs:
        self.ser.addInputTensor(tens)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_src, acc_src = bad_iter, bad_acc
    else:
        iter_src, acc_src = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(iter_src.shape, iter_src.dtype)
    acc_body_out = self.ser.addIntermediate(acc_src.shape, acc_src.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for tens in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tens)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks,
    )
    if not is_valid:
        return None

    return acc_out
2049
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Resolve the shape, rank and dtype filters used to generate one op's tests.

    The requested filters are combined with what the operator entry allows
    (op["rank"] and op["types"]).

    Args:
        op: operator entry from TOSA_OP_LIST.
        shapeFilter: list of requested shapes; None or empty means no restriction.
        rankFilter: list of requested ranks, or None.
        dtypeFilter: list of requested dtypes, or None.
        testType: "positive" or "negative".
        validator: ERROR_IF validator; required for negative tests.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" entries, or
        None when testType is neither "positive" nor "negative", or when a
        negative test has no validator.
    """
    # Ranks 1-4 inclusive: keeps default test sizes reasonably small
    default_test_rank_range = range(1, 5)
    shapes = shapeFilter if shapeFilter else [None]

    rmin, rmax = op["rank"]
    opRankRange = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Keep only the requested ranks that the operator supports
        ranks = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapes[0] is not None:
        # Shapes were requested, so allow every rank the operator supports
        ranks = opRankRange
    elif len(opRankRange) > len(default_test_rank_range):
        # Operator allows more ranks than the default range holds; use the default
        ranks = default_test_rank_range
    else:
        ranks = opRankRange

    opDtypes = op["types"]
    if dtypeFilter is None:
        dtypes = opDtypes
    else:
        # A dtype entry may itself be a list of dtypes
        # ([Input Type 1, Input Type 2, Accumulator Type]); such entries are
        # matched on their first dtype
        dtypes = [
            dtype
            for dtype in opDtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapes,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }
    if testType != "negative" or validator is None:
        return None

    # Negative (ERROR_IF) tests: the validator may require particular
    # parameters; anything it leaves as None falls back to the filters above
    reqs = validator(check=False, op=op)["param_reqs"]
    reqRank = reqs["rank"]
    reqDtype = reqs["dtype"]
    reqShape = reqs["shape"]
    return {
        # Only the first two shapes are used to keep the number of tests small
        "shapeFilter": shapes[:2] if reqShape is None else reqShape,
        "rankFilter": ranks if reqRank is None else reqRank,
        "dtypeFilter": dtypes if reqDtype is None else reqDtype,
    }
2130
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests for one operator.

    Args:
        opName: key of the operator in self.TOSA_OP_LIST.
        shapeFilter: optional list of shapes to restrict generation to; None
            (the default) or an empty list means no shape restriction.
        rankFilter: optional list of ranks to restrict generation to.
        dtypeFilter: optional list of dtypes to restrict generation to.
        testType: "positive" for normal tests or "negative" for ERROR_IF tests.

    Returns:
        A list of (opName, testStr, dtype, error_name, shapeList, args) tuples.
        For positive tests, a test is dropped if ANY of the op's
        "invalid_test_validators" rejects it.

    Raises:
        Exception: if opName is not in self.TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Initialize a new random number generator from the configured seed
    self.rng = np.random.default_rng(self.random_seed)

    _build_fcn, tgen_fcn, _tvgen_fcn, agen_fcn = op["build_fcn"]

    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        # A single pass with no ERROR_IF validator
        error_if_validators = [None]

    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            # Nothing to generate for this validator; keep any tests already
            # produced for earlier validators
            continue

        for r in filterDict["rankFilter"]:
            for t in filterDict["dtypeFilter"]:
                for shape in filterDict["shapeFilter"]:
                    # Skip requested shapes that do not match this rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Each entry is (argument string, build function argument list)
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "negative":
                            testStr = f"{opName}_ERRORIF_{error_name}_{shapeStr}_{typeStr}"
                        else:
                            testStr = f"{opName}_{shapeStr}_{typeStr}"
                        if argStr:
                            testStr += f"_{argStr}"

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement. A test is removed if ANY validator rejects it
        # (previously the flag was reset per validator, so only the last
        # validator's verdict counted).
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2237
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build the graph for one generated test and serialize it if it is valid.

    Args:
        opName: key of the operator in self.TOSA_OP_LIST.
        testStr: test name, passed to createSerializer.
        dtype_or_dtypeList: a single dtype used for every operand, or a list
            with one dtype per operand.
        error_name: ERROR_IF error being tested, or None for a positive test.
        shapeList: list of operand shapes.
        testArgs: extra positional arguments for the operator's build function.

    Raises:
        Exception: if opName is not in self.TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _tgen_fcn, tvgen_fcn, _agen_fcn = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif op["op"] == Op.CONCAT:
        # CONCAT has one input per shape rather than a fixed operand count
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * (num_operands)

    if op["op"] != Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    if qgen is not None:
        qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
    else:
        qinfo = None

    # Build the random tensor operands
    tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

    # Optional build-function arguments are always passed by keyword so they
    # bind to the named parameters regardless of the build function's
    # positional layout (qinfo was previously positional when there were no
    # ERROR_IF validators, but keyword otherwise)
    build_kwargs = {}
    if error_if_validators is not None:
        build_kwargs["validator_fcns"] = error_if_validators
        build_kwargs["error_name"] = error_name
    if qinfo is not None:
        build_kwargs["qinfo"] = qinfo

    try:
        resultName = build_fcn(self, op, *tens, *testArgs, **build_kwargs)
    except TypeError:
        # Usually an argument-count mismatch; show what was passed
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise

    if resultName:
        # The test is valid, serialize it
        self.serialize("test")
    else:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002328
def createDynamicOpLists(self):
    """Expand the templated convolution ops into concrete per-kernel ops.

    For every kernel size, each "<prefix>_TEMPLATE" entry is shallow-copied to
    "<prefix>_<k0>x<k1>[x<k2>]" with "filter" set to the kernel and "template"
    set to False. Afterwards every entry whose "template" flag is truthy is
    deleted. Does nothing if "conv2d_TEMPLATE" is already gone.
    """
    opList = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in opList:
        # Already created these lists (can occur when class is initialized more than once)
        return

    def addKernelVariant(prefix, kernel):
        """Add the concrete op for one kernel size, copied from its template."""
        name = "{}_{}".format(prefix, "x".join(str(dim) for dim in kernel))
        variant = opList[prefix + "_TEMPLATE"].copy()
        variant["filter"] = kernel
        variant["template"] = False
        opList[name] = variant

    # Kernel-major order, so the 2D variants for one kernel are added together
    for kernel in ([1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]):
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            addKernelVariant(prefix, kernel)

    for kernel in ([1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]):
        addKernelVariant("conv3d", kernel)

    # Collect the template keys first: a dict must not change size while it
    # is being iterated
    templateNames = [name for name, opDef in opList.items() if opDef.get("template")]
    for name in templateNames:
        del opList[name]
2379
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.

    Lints every TOSA_OP_LIST entry: it must have a 2-element "operands"
    tuple, a 4-element "build_fcn" tuple, a "types" entry and an "op" entry,
    otherwise an Exception naming the op is raised. Entries without a "rank"
    get self.DEFAULT_RANK_RANGE.
    """
    for opName, opDef in self.TOSA_OP_LIST.items():
        # Required: (placeholder, const) operand counts
        try:
            _placeholders, _consts = opDef["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(opName)
            )

        # Required: (build, tensor gen, tensor values gen, arg gen) functions
        try:
            _build, _tensorGen, _valuesGen, _argGen = opDef["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    opName
                )
            )

        if "types" not in opDef:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(opName)
            )

        if "op" not in opDef:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(opName)
            )

        # Optional: default rank range
        if "rank" not in opDef:
            opDef["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002421
# Tensor operator list (TOSA_OP_LIST); each entry is a dict with keys:
# 'op': the tosa Op enum value for the operator
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#    if not specified, initOpListDefaults fills in DEFAULT_RANK_RANGE,
#    i.e. (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': tuple of (build_operator() function, TosaTensorGen shape
#    function, TosaTensorValuesGen tensor values function,
#    TosaArgGen argument function or None)
# 'types': array of datatypes to be tested; an entry may itself be a list
#    of datatypes (see TYPE_CONV)
# Entries may also carry 'qgen', 'invalid_test_validators',
# 'error_if_validators', and 'template'/'filter' (templated convolutions).

# Floating-point types
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
# Integer and floating-point types
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
# Floating-point, integer (excluding INT4) and BOOL types
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
# FP32 and INT16 only
TYPE_FI16 = [DType.FP32, DType.INT16]

# INT8/INT16 plus the floating-point types (no INT32)
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
# Dtype filtering in create_filter_lists matches these list entries on their
# first element (Input Type 1)
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# (min, max) rank range given to ops that have no explicit 'rank' entry
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002473
2474 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002475 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002476 "argmax": {
2477 "op": Op.ARGMAX,
2478 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002479 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002480 "build_fcn": (
2481 build_argmax,
2482 TosaTensorGen.tgBasic,
2483 TosaTensorValuesGen.tvgDefault,
2484 TosaArgGen.agAxis,
2485 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002486 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002487 "error_if_validators": (
2488 TosaErrorValidator.evAxisSmallerZero,
2489 TosaErrorValidator.evAxisLargerRank,
2490 TosaErrorValidator.evArgmaxOutputRankMismatch,
2491 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2492 TosaErrorValidator.evWrongRank,
2493 TosaErrorValidator.evWrongInputType,
2494 TosaErrorValidator.evWrongOutputType,
2495 TosaErrorValidator.evWrongInputList,
2496 TosaErrorValidator.evWrongOutputList,
2497 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002498 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002499 "avg_pool2d": {
2500 "op": Op.AVG_POOL2D,
2501 "operands": (1, 0),
2502 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002503 "build_fcn": (
2504 build_pool2d,
2505 TosaTensorGen.tgNHWC,
2506 TosaTensorValuesGen.tvgDefault,
2507 TosaArgGen.agPooling,
2508 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002509 "qgen": TosaQuantGen.qgUnary,
2510 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002511 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002512 "error_if_validators": (
2513 TosaErrorValidator.evKernelSmallerOne,
2514 TosaErrorValidator.evStrideSmallerOne,
2515 TosaErrorValidator.evPadSmallerZero,
2516 TosaErrorValidator.evWrongRank,
2517 TosaErrorValidator.evWrongInputType,
2518 TosaErrorValidator.evWrongOutputType,
2519 TosaErrorValidator.evWrongInputList,
2520 TosaErrorValidator.evWrongOutputList,
2521 TosaErrorValidator.evInputZeroPointNotZero,
2522 TosaErrorValidator.evOutputZeroPointNotZero,
2523 TosaErrorValidator.evPadLargerEqualKernel,
2524 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002525 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002526 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002527 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002528 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002529 "conv2d_TEMPLATE": {
2530 "op": Op.CONV2D,
2531 "operands": (1, 2),
2532 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002533 "build_fcn": (
2534 build_conv2d,
2535 TosaTensorGen.tgConv2D,
2536 TosaTensorValuesGen.tvgDefault,
2537 TosaArgGen.agConv,
2538 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002539 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002540 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002541 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2542 "error_if_validators": (
2543 TosaErrorValidator.evWrongInputType,
2544 TosaErrorValidator.evWrongOutputType,
2545 TosaErrorValidator.evWrongInputList,
2546 TosaErrorValidator.evWrongOutputList,
2547 TosaErrorValidator.evInputZeroPointNotZero,
2548 TosaErrorValidator.evWeightZeroPointNotZero,
2549 TosaErrorValidator.evPadSmallerZero,
2550 TosaErrorValidator.evStrideSmallerOne,
2551 TosaErrorValidator.evDilationSmallerOne,
2552 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002553 TosaErrorValidator.evConvOutputShapeMismatch,
2554 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002555 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002556 "template": True,
2557 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002558 # Templated operator. Filled in by createDynamicOpLists
2559 "conv3d_TEMPLATE": {
2560 "op": Op.CONV3D,
2561 "operands": (1, 2),
2562 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002563 "build_fcn": (
2564 build_conv3d,
2565 TosaTensorGen.tgConv3D,
2566 TosaTensorValuesGen.tvgDefault,
2567 TosaArgGen.agConv,
2568 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002569 "qgen": TosaQuantGen.qgConv,
2570 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002571 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2572 "error_if_validators": (
2573 TosaErrorValidator.evWrongInputType,
2574 TosaErrorValidator.evWrongOutputType,
2575 TosaErrorValidator.evWrongInputList,
2576 TosaErrorValidator.evWrongOutputList,
2577 TosaErrorValidator.evInputZeroPointNotZero,
2578 TosaErrorValidator.evWeightZeroPointNotZero,
2579 TosaErrorValidator.evPadSmallerZero,
2580 TosaErrorValidator.evStrideSmallerOne,
2581 TosaErrorValidator.evDilationSmallerOne,
2582 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002583 TosaErrorValidator.evConvOutputShapeMismatch,
2584 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002585 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002586 "template": True,
2587 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002588 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002589 "depthwise_conv2d_TEMPLATE": {
2590 "op": Op.DEPTHWISE_CONV2D,
2591 "operands": (1, 2),
2592 "filter": [1, 1],
2593 "rank": (4, 4),
2594 "build_fcn": (
2595 build_depthwise_conv2d,
2596 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002597 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002598 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002599 ),
2600 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002601 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002602 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2603 "error_if_validators": (
2604 TosaErrorValidator.evWrongInputType,
2605 TosaErrorValidator.evWrongOutputType,
2606 TosaErrorValidator.evWrongInputList,
2607 TosaErrorValidator.evWrongOutputList,
2608 TosaErrorValidator.evInputZeroPointNotZero,
2609 TosaErrorValidator.evWeightZeroPointNotZero,
2610 TosaErrorValidator.evPadSmallerZero,
2611 TosaErrorValidator.evStrideSmallerOne,
2612 TosaErrorValidator.evDilationSmallerOne,
2613 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002614 TosaErrorValidator.evConvOutputShapeMismatch,
2615 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002616 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002617 "template": True,
2618 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002619 "fully_connected": {
2620 "op": Op.FULLY_CONNECTED,
2621 "operands": (1, 2),
2622 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002623 "build_fcn": (
2624 build_fully_connected,
2625 TosaTensorGen.tgFullyConnected,
2626 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002627 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002628 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002629 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002630 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002631 "error_if_validators": (
2632 TosaErrorValidator.evInputZeroPointNotZero,
2633 TosaErrorValidator.evWeightZeroPointNotZero,
2634 TosaErrorValidator.evWrongRank,
2635 TosaErrorValidator.evWrongInputType,
2636 TosaErrorValidator.evWrongOutputType,
2637 TosaErrorValidator.evWrongInputList,
2638 TosaErrorValidator.evWrongOutputList,
2639 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002640 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002641 "matmul": {
2642 "op": Op.MATMUL,
2643 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002644 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002645 "build_fcn": (
2646 build_matmul,
2647 TosaTensorGen.tgMatmul,
2648 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002649 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002650 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002651 "qgen": TosaQuantGen.qgMatmul,
2652 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002653 "error_if_validators": (
2654 TosaErrorValidator.evInputZeroPointNotZero,
2655 TosaErrorValidator.evWrongRank,
2656 TosaErrorValidator.evWrongInputType,
2657 TosaErrorValidator.evWrongOutputType,
2658 TosaErrorValidator.evWrongInputList,
2659 TosaErrorValidator.evWrongOutputList,
2660 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002661 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002662 "max_pool2d": {
2663 "op": Op.MAX_POOL2D,
2664 "operands": (1, 0),
2665 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002666 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01002667 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002668 TosaTensorGen.tgNHWC,
2669 TosaTensorValuesGen.tvgDefault,
2670 TosaArgGen.agPooling,
2671 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002672 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002673 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002674 "error_if_validators": (
2675 TosaErrorValidator.evKernelSmallerOne,
2676 TosaErrorValidator.evStrideSmallerOne,
2677 TosaErrorValidator.evPadSmallerZero,
2678 TosaErrorValidator.evWrongRank,
2679 TosaErrorValidator.evWrongInputType,
2680 TosaErrorValidator.evWrongOutputType,
2681 TosaErrorValidator.evWrongInputList,
2682 TosaErrorValidator.evWrongOutputList,
2683 TosaErrorValidator.evPadLargerEqualKernel,
2684 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002685 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002686 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002687 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002688 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002689 "transpose_conv2d_TEMPLATE": {
2690 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002691 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002692 "rank": (4, 4),
2693 "build_fcn": (
2694 build_transpose_conv2d,
2695 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002696 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002697 TosaArgGen.agTransposeConv2D,
2698 ),
2699 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002700 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002701 "invalid_test_validators": (
2702 TosaInvalidValidator.ivHeightWidthInvalid,
2703 TosaInvalidValidator.ivNonPositiveOutputShape,
2704 ),
2705 "error_if_validators": (
2706 TosaErrorValidator.evWrongInputType,
2707 TosaErrorValidator.evWrongOutputType,
2708 TosaErrorValidator.evWrongInputList,
2709 TosaErrorValidator.evWrongOutputList,
2710 TosaErrorValidator.evInputZeroPointNotZero,
2711 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07002712 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00002713 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002714 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002715 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002716 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002717 "template": True,
2718 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002719 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002720 "clamp": {
2721 "op": Op.CLAMP,
2722 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002723 "build_fcn": (
2724 build_clamp,
2725 TosaTensorGen.tgBasic,
2726 TosaTensorValuesGen.tvgDefault,
2727 None,
2728 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002729 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002730 "error_if_validators": (
2731 TosaErrorValidator.evMaxSmallerMin,
2732 TosaErrorValidator.evWrongInputType,
2733 TosaErrorValidator.evWrongOutputType,
2734 TosaErrorValidator.evWrongInputList,
2735 TosaErrorValidator.evWrongOutputList,
2736 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002737 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002738 "sigmoid": {
2739 "op": Op.SIGMOID,
2740 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002741 "build_fcn": (
2742 build_sigmoid,
2743 TosaTensorGen.tgBasic,
2744 TosaTensorValuesGen.tvgDefault,
2745 None,
2746 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002747 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002748 "error_if_validators": (
2749 TosaErrorValidator.evWrongInputType,
2750 TosaErrorValidator.evWrongOutputType,
2751 TosaErrorValidator.evWrongInputList,
2752 TosaErrorValidator.evWrongOutputList,
2753 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002754 },
2755 "tanh": {
2756 "op": Op.TANH,
2757 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002758 "build_fcn": (
2759 build_tanh,
2760 TosaTensorGen.tgBasic,
2761 TosaTensorValuesGen.tvgDefault,
2762 None,
2763 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002764 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002765 "error_if_validators": (
2766 TosaErrorValidator.evWrongInputType,
2767 TosaErrorValidator.evWrongOutputType,
2768 TosaErrorValidator.evWrongInputList,
2769 TosaErrorValidator.evWrongOutputList,
2770 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002771 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002772 # Elementwise Binary Operators
2773 "add": {
2774 "op": Op.ADD,
2775 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002776 "build_fcn": (
2777 build_binary_broadcast,
2778 TosaTensorGen.tgBroadcastFuzz,
2779 TosaTensorValuesGen.tvgAddSub,
2780 None,
2781 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002782 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002783 "error_if_validators": (
2784 TosaErrorValidator.evRankMismatch,
2785 TosaErrorValidator.evWrongInputType,
2786 TosaErrorValidator.evWrongOutputType,
2787 TosaErrorValidator.evWrongInputList,
2788 TosaErrorValidator.evWrongOutputList,
2789 TosaErrorValidator.evDimensionMismatch,
2790 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002791 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002792 "arithmetic_right_shift": {
2793 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2794 "operands": (2, 0),
2795 "build_fcn": (
2796 build_arithmetic_right_shift,
2797 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002798 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002799 TosaArgGen.agArithmeticRightShift,
2800 ),
2801 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002802 "error_if_validators": (
2803 TosaErrorValidator.evRankMismatch,
2804 TosaErrorValidator.evWrongInputType,
2805 TosaErrorValidator.evWrongOutputType,
2806 TosaErrorValidator.evWrongInputList,
2807 TosaErrorValidator.evWrongOutputList,
2808 TosaErrorValidator.evDimensionMismatch,
2809 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002810 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002811 "bitwise_and": {
2812 "op": Op.BITWISE_AND,
2813 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002814 "build_fcn": (
2815 build_binary_broadcast,
2816 TosaTensorGen.tgBroadcastFuzz,
2817 TosaTensorValuesGen.tvgDefault,
2818 None,
2819 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002820 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002821 "error_if_validators": (
2822 TosaErrorValidator.evRankMismatch,
2823 TosaErrorValidator.evWrongInputType,
2824 TosaErrorValidator.evWrongOutputType,
2825 TosaErrorValidator.evWrongInputList,
2826 TosaErrorValidator.evWrongOutputList,
2827 TosaErrorValidator.evDimensionMismatch,
2828 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002829 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002830 "bitwise_or": {
2831 "op": Op.BITWISE_OR,
2832 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002833 "build_fcn": (
2834 build_binary_broadcast,
2835 TosaTensorGen.tgBroadcastFuzz,
2836 TosaTensorValuesGen.tvgDefault,
2837 None,
2838 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002839 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002840 "error_if_validators": (
2841 TosaErrorValidator.evRankMismatch,
2842 TosaErrorValidator.evWrongInputType,
2843 TosaErrorValidator.evWrongOutputType,
2844 TosaErrorValidator.evWrongInputList,
2845 TosaErrorValidator.evWrongOutputList,
2846 TosaErrorValidator.evDimensionMismatch,
2847 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002848 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002849 "bitwise_xor": {
2850 "op": Op.BITWISE_XOR,
2851 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002852 "build_fcn": (
2853 build_binary_broadcast,
2854 TosaTensorGen.tgBroadcastFuzz,
2855 TosaTensorValuesGen.tvgDefault,
2856 None,
2857 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002858 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002859 "error_if_validators": (
2860 TosaErrorValidator.evRankMismatch,
2861 TosaErrorValidator.evWrongInputType,
2862 TosaErrorValidator.evWrongOutputType,
2863 TosaErrorValidator.evWrongInputList,
2864 TosaErrorValidator.evWrongOutputList,
2865 TosaErrorValidator.evDimensionMismatch,
2866 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002867 },
Matthew Haddon459443c2021-08-23 16:43:13 +01002868 "intdiv": {
2869 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002870 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002871 "build_fcn": (
2872 build_binary_broadcast,
2873 TosaTensorGen.tgBroadcastFuzz,
2874 TosaTensorValuesGen.tvgIntDiv,
2875 None,
2876 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002877 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002878 "error_if_validators": (
2879 TosaErrorValidator.evRankMismatch,
2880 TosaErrorValidator.evWrongInputType,
2881 TosaErrorValidator.evWrongOutputType,
2882 TosaErrorValidator.evWrongInputList,
2883 TosaErrorValidator.evWrongOutputList,
2884 TosaErrorValidator.evDimensionMismatch,
2885 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002886 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002887 "logical_and": {
2888 "op": Op.LOGICAL_AND,
2889 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002890 "build_fcn": (
2891 build_binary_broadcast,
2892 TosaTensorGen.tgBroadcastFuzz,
2893 TosaTensorValuesGen.tvgDefault,
2894 None,
2895 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002896 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002897 "error_if_validators": (
2898 TosaErrorValidator.evRankMismatch,
2899 TosaErrorValidator.evWrongInputType,
2900 TosaErrorValidator.evWrongOutputType,
2901 TosaErrorValidator.evWrongInputList,
2902 TosaErrorValidator.evWrongOutputList,
2903 TosaErrorValidator.evDimensionMismatch,
2904 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002905 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002906 "logical_left_shift": {
2907 "op": Op.LOGICAL_LEFT_SHIFT,
2908 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002909 "build_fcn": (
2910 build_binary_broadcast,
2911 TosaTensorGen.tgBroadcastFuzz,
2912 TosaTensorValuesGen.tvgLogicalShift,
2913 None,
2914 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002915 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002916 "error_if_validators": (
2917 TosaErrorValidator.evRankMismatch,
2918 TosaErrorValidator.evWrongInputType,
2919 TosaErrorValidator.evWrongOutputType,
2920 TosaErrorValidator.evWrongInputList,
2921 TosaErrorValidator.evWrongOutputList,
2922 TosaErrorValidator.evDimensionMismatch,
2923 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002924 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002925 "logical_right_shift": {
2926 "op": Op.LOGICAL_RIGHT_SHIFT,
2927 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002928 "build_fcn": (
2929 build_binary_broadcast,
2930 TosaTensorGen.tgBroadcastFuzz,
2931 TosaTensorValuesGen.tvgLogicalShift,
2932 None,
2933 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002934 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002935 "error_if_validators": (
2936 TosaErrorValidator.evRankMismatch,
2937 TosaErrorValidator.evWrongInputType,
2938 TosaErrorValidator.evWrongOutputType,
2939 TosaErrorValidator.evWrongInputList,
2940 TosaErrorValidator.evWrongOutputList,
2941 TosaErrorValidator.evDimensionMismatch,
2942 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002943 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002944 "logical_or": {
2945 "op": Op.LOGICAL_OR,
2946 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002947 "build_fcn": (
2948 build_binary_broadcast,
2949 TosaTensorGen.tgBroadcastFuzz,
2950 TosaTensorValuesGen.tvgDefault,
2951 None,
2952 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002953 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002954 "error_if_validators": (
2955 TosaErrorValidator.evRankMismatch,
2956 TosaErrorValidator.evWrongInputType,
2957 TosaErrorValidator.evWrongOutputType,
2958 TosaErrorValidator.evWrongInputList,
2959 TosaErrorValidator.evWrongOutputList,
2960 TosaErrorValidator.evDimensionMismatch,
2961 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002962 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002963 "logical_xor": {
2964 "op": Op.LOGICAL_XOR,
2965 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002966 "build_fcn": (
2967 build_binary_broadcast,
2968 TosaTensorGen.tgBroadcastFuzz,
2969 TosaTensorValuesGen.tvgDefault,
2970 None,
2971 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002972 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002973 "error_if_validators": (
2974 TosaErrorValidator.evRankMismatch,
2975 TosaErrorValidator.evWrongInputType,
2976 TosaErrorValidator.evWrongOutputType,
2977 TosaErrorValidator.evWrongInputList,
2978 TosaErrorValidator.evWrongOutputList,
2979 TosaErrorValidator.evDimensionMismatch,
2980 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002981 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002982 "maximum": {
2983 "op": Op.MAXIMUM,
2984 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002985 "build_fcn": (
2986 build_binary_broadcast,
2987 TosaTensorGen.tgBroadcastFuzz,
2988 TosaTensorValuesGen.tvgDefault,
2989 None,
2990 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002991 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002992 "error_if_validators": (
2993 TosaErrorValidator.evRankMismatch,
2994 TosaErrorValidator.evWrongInputType,
2995 TosaErrorValidator.evWrongOutputType,
2996 TosaErrorValidator.evWrongInputList,
2997 TosaErrorValidator.evWrongOutputList,
2998 TosaErrorValidator.evDimensionMismatch,
2999 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003000 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003001 "minimum": {
3002 "op": Op.MINIMUM,
3003 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003004 "build_fcn": (
3005 build_binary_broadcast,
3006 TosaTensorGen.tgBroadcastFuzz,
3007 TosaTensorValuesGen.tvgDefault,
3008 None,
3009 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003010 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003011 "error_if_validators": (
3012 TosaErrorValidator.evRankMismatch,
3013 TosaErrorValidator.evWrongInputType,
3014 TosaErrorValidator.evWrongOutputType,
3015 TosaErrorValidator.evWrongInputList,
3016 TosaErrorValidator.evWrongOutputList,
3017 TosaErrorValidator.evDimensionMismatch,
3018 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003019 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003020 "mul": {
3021 "op": Op.MUL,
3022 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003023 "build_fcn": (
3024 build_mul,
3025 TosaTensorGen.tgBroadcastFuzz,
3026 TosaTensorValuesGen.tvgMul,
3027 TosaArgGen.agMul,
3028 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003029 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003030 "error_if_validators": (
3031 TosaErrorValidator.evWrongInputType,
3032 TosaErrorValidator.evWrongOutputType,
3033 TosaErrorValidator.evWrongInputList,
3034 TosaErrorValidator.evWrongOutputList,
3035 TosaErrorValidator.evRankMismatch,
3036 TosaErrorValidator.evDimensionMismatch,
3037 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003038 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003039 "pow": {
3040 "op": Op.POW,
3041 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003042 "build_fcn": (
3043 build_binary_broadcast,
3044 TosaTensorGen.tgBroadcastFuzz,
3045 TosaTensorValuesGen.tvgDefault,
3046 None,
3047 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003048 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003049 "error_if_validators": (
3050 TosaErrorValidator.evRankMismatch,
3051 TosaErrorValidator.evWrongInputType,
3052 TosaErrorValidator.evWrongOutputType,
3053 TosaErrorValidator.evWrongInputList,
3054 TosaErrorValidator.evWrongOutputList,
3055 TosaErrorValidator.evDimensionMismatch,
3056 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003057 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003058 "sub": {
3059 "op": Op.SUB,
3060 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003061 "build_fcn": (
3062 build_binary_broadcast,
3063 TosaTensorGen.tgBroadcastFuzz,
3064 TosaTensorValuesGen.tvgAddSub,
3065 None,
3066 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003067 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003068 "error_if_validators": (
3069 TosaErrorValidator.evRankMismatch,
3070 TosaErrorValidator.evWrongInputType,
3071 TosaErrorValidator.evWrongOutputType,
3072 TosaErrorValidator.evWrongInputList,
3073 TosaErrorValidator.evWrongOutputList,
3074 TosaErrorValidator.evDimensionMismatch,
3075 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003076 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003077 "table": {
3078 "op": Op.TABLE,
3079 # Use the automatic generation functions to create the input array
3080 # but create the table tensor in the build function, as it may be
3081 # a different type from the input
3082 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003083 "build_fcn": (
3084 build_table,
3085 TosaTensorGen.tgBasic,
3086 TosaTensorValuesGen.tvgDefault,
3087 TosaArgGen.agTable,
3088 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003089 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003090 "error_if_validators": (
3091 TosaErrorValidator.evWrongInputType,
3092 TosaErrorValidator.evWrongOutputType,
3093 TosaErrorValidator.evWrongInputList,
3094 TosaErrorValidator.evWrongOutputList,
3095 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003096 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003097 # Elementwise Unary operators
3098 "abs": {
3099 "op": Op.ABS,
3100 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003101 "build_fcn": (
3102 build_unary,
3103 TosaTensorGen.tgBasic,
3104 TosaTensorValuesGen.tvgDefault,
3105 None,
3106 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003107 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003108 "error_if_validators": (
3109 TosaErrorValidator.evWrongInputType,
3110 TosaErrorValidator.evWrongOutputType,
3111 TosaErrorValidator.evWrongInputList,
3112 TosaErrorValidator.evWrongOutputList,
3113 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003114 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003115 "bitwise_not": {
3116 "op": Op.BITWISE_NOT,
3117 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003118 "build_fcn": (
3119 build_unary,
3120 TosaTensorGen.tgBasic,
3121 TosaTensorValuesGen.tvgDefault,
3122 None,
3123 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003124 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003125 "error_if_validators": (
3126 TosaErrorValidator.evWrongInputType,
3127 TosaErrorValidator.evWrongOutputType,
3128 TosaErrorValidator.evWrongInputList,
3129 TosaErrorValidator.evWrongOutputList,
3130 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003131 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003132 "ceil": {
3133 "op": Op.CEIL,
3134 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003135 "build_fcn": (
3136 build_unary,
3137 TosaTensorGen.tgBasic,
3138 TosaTensorValuesGen.tvgDefault,
3139 None,
3140 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003141 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003142 "error_if_validators": (
3143 TosaErrorValidator.evWrongInputType,
3144 TosaErrorValidator.evWrongOutputType,
3145 TosaErrorValidator.evWrongInputList,
3146 TosaErrorValidator.evWrongOutputList,
3147 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003148 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003149 "clz": {
3150 "op": Op.CLZ,
3151 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003152 "build_fcn": (
3153 build_unary,
3154 TosaTensorGen.tgBasic,
3155 TosaTensorValuesGen.tvgDefault,
3156 None,
3157 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003158 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003159 "error_if_validators": (
3160 TosaErrorValidator.evWrongInputType,
3161 TosaErrorValidator.evWrongOutputType,
3162 TosaErrorValidator.evWrongInputList,
3163 TosaErrorValidator.evWrongOutputList,
3164 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003165 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003166 "exp": {
3167 "op": Op.EXP,
3168 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003169 "build_fcn": (
3170 build_unary,
3171 TosaTensorGen.tgBasic,
3172 TosaTensorValuesGen.tvgDefault,
3173 None,
3174 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003175 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003176 "error_if_validators": (
3177 TosaErrorValidator.evWrongInputType,
3178 TosaErrorValidator.evWrongOutputType,
3179 TosaErrorValidator.evWrongInputList,
3180 TosaErrorValidator.evWrongOutputList,
3181 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003182 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003183 "floor": {
3184 "op": Op.FLOOR,
3185 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003186 "build_fcn": (
3187 build_unary,
3188 TosaTensorGen.tgBasic,
3189 TosaTensorValuesGen.tvgDefault,
3190 None,
3191 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003192 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003193 "error_if_validators": (
3194 TosaErrorValidator.evWrongInputType,
3195 TosaErrorValidator.evWrongOutputType,
3196 TosaErrorValidator.evWrongInputList,
3197 TosaErrorValidator.evWrongOutputList,
3198 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003199 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003200 "log": {
3201 "op": Op.LOG,
3202 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003203 "build_fcn": (
3204 build_unary,
3205 TosaTensorGen.tgBasic,
3206 TosaTensorValuesGen.tvgDefault,
3207 None,
3208 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003209 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003210 "error_if_validators": (
3211 TosaErrorValidator.evWrongInputType,
3212 TosaErrorValidator.evWrongOutputType,
3213 TosaErrorValidator.evWrongInputList,
3214 TosaErrorValidator.evWrongOutputList,
3215 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003216 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003217 "logical_not": {
3218 "op": Op.LOGICAL_NOT,
3219 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003220 "build_fcn": (
3221 build_unary,
3222 TosaTensorGen.tgBasic,
3223 TosaTensorValuesGen.tvgDefault,
3224 None,
3225 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003226 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003227 "error_if_validators": (
3228 TosaErrorValidator.evWrongInputType,
3229 TosaErrorValidator.evWrongOutputType,
3230 TosaErrorValidator.evWrongInputList,
3231 TosaErrorValidator.evWrongOutputList,
3232 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003233 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003234 "negate": {
3235 "op": Op.NEGATE,
3236 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003237 "build_fcn": (
3238 build_unary,
3239 TosaTensorGen.tgBasic,
3240 TosaTensorValuesGen.tvgNegate,
3241 None,
3242 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003243 "qgen": TosaQuantGen.qgUnary,
3244 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003245 "error_if_validators": (
3246 TosaErrorValidator.evInputZeroPointNotZero,
3247 TosaErrorValidator.evOutputZeroPointNotZero,
3248 TosaErrorValidator.evWrongInputType,
3249 TosaErrorValidator.evWrongOutputType,
3250 TosaErrorValidator.evWrongInputList,
3251 TosaErrorValidator.evWrongOutputList,
3252 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003253 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003254 "reciprocal": {
3255 "op": Op.RECIPROCAL,
3256 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003257 "build_fcn": (
3258 build_unary,
3259 TosaTensorGen.tgBasic,
3260 TosaTensorValuesGen.tvgDefault,
3261 None,
3262 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003263 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003264 "error_if_validators": (
3265 TosaErrorValidator.evWrongInputType,
3266 TosaErrorValidator.evWrongOutputType,
3267 TosaErrorValidator.evWrongInputList,
3268 TosaErrorValidator.evWrongOutputList,
3269 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003270 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003271 "rsqrt": {
3272 "op": Op.RSQRT,
3273 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003274 "build_fcn": (
3275 build_unary,
3276 TosaTensorGen.tgBasic,
3277 TosaTensorValuesGen.tvgDefault,
3278 None,
3279 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003280 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003281 "error_if_validators": (
3282 TosaErrorValidator.evWrongInputType,
3283 TosaErrorValidator.evWrongOutputType,
3284 TosaErrorValidator.evWrongInputList,
3285 TosaErrorValidator.evWrongOutputList,
3286 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003287 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003288 # Elementwise Ternary operators
3289 "select": {
3290 "op": Op.SELECT,
3291 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003292 "build_fcn": (
3293 build_select,
3294 TosaTensorGen.tgBroadcastFuzz,
3295 TosaTensorValuesGen.tvgSelect,
3296 None,
3297 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003298 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003299 "error_if_validators": (
3300 TosaErrorValidator.evRankMismatch,
3301 TosaErrorValidator.evWrongInputType,
3302 TosaErrorValidator.evWrongOutputType,
3303 TosaErrorValidator.evWrongInputList,
3304 TosaErrorValidator.evWrongOutputList,
3305 TosaErrorValidator.evDimensionMismatch,
3306 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003307 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003308 # Comparison operators
3309 "equal": {
3310 "op": Op.EQUAL,
3311 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003312 "build_fcn": (
3313 build_comparison,
3314 TosaTensorGen.tgBroadcastFuzz,
3315 TosaTensorValuesGen.tvgEqual,
3316 None,
3317 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003318 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003319 "error_if_validators": (
3320 TosaErrorValidator.evRankMismatch,
3321 TosaErrorValidator.evWrongInputType,
3322 TosaErrorValidator.evWrongOutputType,
3323 TosaErrorValidator.evWrongInputList,
3324 TosaErrorValidator.evWrongOutputList,
3325 TosaErrorValidator.evDimensionMismatch,
3326 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003327 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003328 "greater_equal": {
3329 "op": Op.GREATER_EQUAL,
3330 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003331 "build_fcn": (
3332 build_comparison,
3333 TosaTensorGen.tgBroadcastFuzz,
3334 TosaTensorValuesGen.tvgDefault,
3335 None,
3336 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003337 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003338 "error_if_validators": (
3339 TosaErrorValidator.evRankMismatch,
3340 TosaErrorValidator.evWrongInputType,
3341 TosaErrorValidator.evWrongOutputType,
3342 TosaErrorValidator.evWrongInputList,
3343 TosaErrorValidator.evWrongOutputList,
3344 TosaErrorValidator.evDimensionMismatch,
3345 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003346 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003347 "greater": {
3348 "op": Op.GREATER,
3349 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003350 "build_fcn": (
3351 build_comparison,
3352 TosaTensorGen.tgBroadcastFuzz,
3353 TosaTensorValuesGen.tvgDefault,
3354 None,
3355 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003356 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003357 "error_if_validators": (
3358 TosaErrorValidator.evRankMismatch,
3359 TosaErrorValidator.evWrongInputType,
3360 TosaErrorValidator.evWrongOutputType,
3361 TosaErrorValidator.evWrongInputList,
3362 TosaErrorValidator.evWrongOutputList,
3363 TosaErrorValidator.evDimensionMismatch,
3364 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003365 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003366 # Reduction operators
3367 "reduce_all": {
3368 "op": Op.REDUCE_ALL,
3369 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003370 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003371 "build_fcn": (
3372 build_reduce,
3373 TosaTensorGen.tgBasic,
3374 TosaTensorValuesGen.tvgDefault,
3375 TosaArgGen.agAxis,
3376 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003377 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003378 "error_if_validators": (
3379 TosaErrorValidator.evAxisLargerRank,
3380 TosaErrorValidator.evAxisSmallerZero,
3381 TosaErrorValidator.evShapeOfAxisNotOne,
3382 TosaErrorValidator.evWrongInputType,
3383 TosaErrorValidator.evWrongOutputType,
3384 TosaErrorValidator.evWrongRank,
3385 TosaErrorValidator.evWrongInputList,
3386 TosaErrorValidator.evWrongOutputList,
3387 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003388 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003389 "reduce_any": {
3390 "op": Op.REDUCE_ANY,
3391 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003392 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003393 "build_fcn": (
3394 build_reduce,
3395 TosaTensorGen.tgBasic,
3396 TosaTensorValuesGen.tvgDefault,
3397 TosaArgGen.agAxis,
3398 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003399 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003400 "error_if_validators": (
3401 TosaErrorValidator.evAxisLargerRank,
3402 TosaErrorValidator.evAxisSmallerZero,
3403 TosaErrorValidator.evShapeOfAxisNotOne,
3404 TosaErrorValidator.evWrongInputType,
3405 TosaErrorValidator.evWrongOutputType,
3406 TosaErrorValidator.evWrongRank,
3407 TosaErrorValidator.evWrongInputList,
3408 TosaErrorValidator.evWrongOutputList,
3409 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003410 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003411 "reduce_max": {
3412 "op": Op.REDUCE_MAX,
3413 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003414 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003415 "build_fcn": (
3416 build_reduce,
3417 TosaTensorGen.tgBasic,
3418 TosaTensorValuesGen.tvgDefault,
3419 TosaArgGen.agAxis,
3420 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003421 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003422 "error_if_validators": (
3423 TosaErrorValidator.evAxisLargerRank,
3424 TosaErrorValidator.evAxisSmallerZero,
3425 TosaErrorValidator.evShapeOfAxisNotOne,
3426 TosaErrorValidator.evWrongInputType,
3427 TosaErrorValidator.evWrongOutputType,
3428 TosaErrorValidator.evWrongRank,
3429 TosaErrorValidator.evWrongInputList,
3430 TosaErrorValidator.evWrongOutputList,
3431 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003432 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003433 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003434 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003435 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003436 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003437 "build_fcn": (
3438 build_reduce,
3439 TosaTensorGen.tgBasic,
3440 TosaTensorValuesGen.tvgDefault,
3441 TosaArgGen.agAxis,
3442 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003443 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003444 "error_if_validators": (
3445 TosaErrorValidator.evAxisLargerRank,
3446 TosaErrorValidator.evAxisSmallerZero,
3447 TosaErrorValidator.evShapeOfAxisNotOne,
3448 TosaErrorValidator.evWrongInputType,
3449 TosaErrorValidator.evWrongOutputType,
3450 TosaErrorValidator.evWrongRank,
3451 TosaErrorValidator.evWrongInputList,
3452 TosaErrorValidator.evWrongOutputList,
3453 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003454 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003455 "reduce_product": {
3456 "op": Op.REDUCE_PRODUCT,
3457 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003458 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003459 "build_fcn": (
3460 build_reduce,
3461 TosaTensorGen.tgBasic,
3462 TosaTensorValuesGen.tvgDefault,
3463 TosaArgGen.agAxis,
3464 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003465 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003466 "error_if_validators": (
3467 TosaErrorValidator.evAxisLargerRank,
3468 TosaErrorValidator.evAxisSmallerZero,
3469 TosaErrorValidator.evShapeOfAxisNotOne,
3470 TosaErrorValidator.evWrongInputType,
3471 TosaErrorValidator.evWrongOutputType,
3472 TosaErrorValidator.evWrongRank,
3473 TosaErrorValidator.evWrongInputList,
3474 TosaErrorValidator.evWrongOutputList,
3475 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003476 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003477 "reduce_sum": {
3478 "op": Op.REDUCE_SUM,
3479 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003480 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003481 "build_fcn": (
3482 build_reduce,
3483 TosaTensorGen.tgBasic,
3484 TosaTensorValuesGen.tvgReduceSum,
3485 TosaArgGen.agAxis,
3486 ),
James Ward24dbc422022-10-19 12:20:31 +01003487 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003488 "error_if_validators": (
3489 TosaErrorValidator.evAxisLargerRank,
3490 TosaErrorValidator.evAxisSmallerZero,
3491 TosaErrorValidator.evShapeOfAxisNotOne,
3492 TosaErrorValidator.evWrongInputType,
3493 TosaErrorValidator.evWrongOutputType,
3494 TosaErrorValidator.evWrongRank,
3495 TosaErrorValidator.evWrongInputList,
3496 TosaErrorValidator.evWrongOutputList,
3497 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003498 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003499 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003500 "concat": {
3501 "op": Op.CONCAT,
3502 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003503 "build_fcn": (
3504 build_concat,
3505 TosaTensorGen.tgConcat,
3506 TosaTensorValuesGen.tvgConcat,
3507 TosaArgGen.agAxis,
3508 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003509 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003510 "error_if_validators": (
3511 TosaErrorValidator.evAxisLargerRank,
3512 TosaErrorValidator.evAxisSmallerZero,
3513 TosaErrorValidator.evConcatInputRankMismatch,
3514 TosaErrorValidator.evConcatShapeSumMismatch,
3515 TosaErrorValidator.evConcatInputDimMismatch,
3516 TosaErrorValidator.evWrongInputType,
3517 TosaErrorValidator.evWrongOutputType,
3518 TosaErrorValidator.evWrongOutputList,
3519 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003520 },
3521 "pad": {
3522 "op": Op.PAD,
3523 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003524 "rank": (1, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003525 "build_fcn": (
3526 build_pad,
3527 TosaTensorGen.tgBasic,
3528 TosaTensorValuesGen.tvgDefault,
3529 TosaArgGen.agPad,
3530 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003531 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003532 "error_if_validators": (
3533 TosaErrorValidator.evWrongInputType,
3534 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003535 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003536 TosaErrorValidator.evWrongOutputType,
3537 TosaErrorValidator.evWrongInputList,
3538 TosaErrorValidator.evWrongOutputList,
3539 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003540 },
3541 "reshape": {
3542 "op": Op.RESHAPE,
3543 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003544 "build_fcn": (
3545 build_reshape,
3546 TosaTensorGen.tgBasic,
3547 TosaTensorValuesGen.tvgDefault,
3548 TosaArgGen.agReshape,
3549 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003550 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003551 "error_if_validators": (
3552 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3553 TosaErrorValidator.evWrongInputType,
3554 TosaErrorValidator.evWrongOutputType,
3555 TosaErrorValidator.evWrongInputList,
3556 TosaErrorValidator.evWrongOutputList,
3557 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003558 },
3559 "reverse": {
3560 "op": Op.REVERSE,
3561 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003562 "build_fcn": (
3563 build_reverse,
3564 TosaTensorGen.tgBasic,
3565 TosaTensorValuesGen.tvgDefault,
3566 TosaArgGen.agAxis,
3567 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003568 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003569 "error_if_validators": (
3570 TosaErrorValidator.evAxisSmallerZero,
3571 TosaErrorValidator.evAxisLargerRank,
3572 TosaErrorValidator.evWrongInputType,
3573 TosaErrorValidator.evWrongOutputType,
3574 TosaErrorValidator.evWrongInputList,
3575 TosaErrorValidator.evWrongOutputList,
3576 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003577 },
3578 "slice": {
3579 "op": Op.SLICE,
3580 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003581 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003582 "build_fcn": (
3583 build_slice,
3584 TosaTensorGen.tgBasic,
3585 TosaTensorValuesGen.tvgDefault,
3586 TosaArgGen.agSlice,
3587 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003588 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003589 "error_if_validators": (
3590 TosaErrorValidator.evStartSmallerZero,
3591 TosaErrorValidator.evSizeSmallerEqualZero,
3592 TosaErrorValidator.evStartSizeOutsideBounds,
3593 TosaErrorValidator.evSizeOutputShapeMismatch,
3594 TosaErrorValidator.evInputSizeStartLengthMismatch,
3595 TosaErrorValidator.evWrongRank,
3596 TosaErrorValidator.evWrongInputType,
3597 TosaErrorValidator.evWrongOutputType,
3598 TosaErrorValidator.evWrongInputList,
3599 TosaErrorValidator.evWrongOutputList,
3600 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003601 },
3602 "tile": {
3603 "op": Op.TILE,
3604 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003605 "build_fcn": (
3606 build_tile,
3607 TosaTensorGen.tgBasic,
3608 TosaTensorValuesGen.tvgDefault,
3609 TosaArgGen.agTile,
3610 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003611 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003612 "error_if_validators": (
3613 TosaErrorValidator.evWrongInputType,
3614 TosaErrorValidator.evWrongOutputType,
3615 TosaErrorValidator.evWrongInputList,
3616 TosaErrorValidator.evWrongOutputList,
3617 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003618 },
3619 "transpose": {
3620 "op": Op.TRANSPOSE,
3621 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003622 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003623 "build_fcn": (
3624 build_transpose,
3625 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003626 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003627 TosaArgGen.agTranspose,
3628 ),
3629 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003630 "error_if_validators": (
3631 TosaErrorValidator.evIndexOutsideBounds,
3632 TosaErrorValidator.evIndexUsedTwice,
3633 TosaErrorValidator.evWrongInputType,
3634 TosaErrorValidator.evWrongOutputType,
3635 TosaErrorValidator.evWrongInputList,
3636 TosaErrorValidator.evWrongOutputList,
3637 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003638 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003639 # Data nodes
3640 "const": {
3641 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003642 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003643 "build_fcn": (
3644 build_const,
3645 TosaTensorGen.tgBasic,
3646 TosaTensorValuesGen.tvgDefault,
3647 None,
3648 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003649 "types": TYPE_FIB,
3650 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003651 "identity": {
3652 "op": Op.IDENTITY,
3653 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003654 "build_fcn": (
3655 build_unary,
3656 TosaTensorGen.tgBasic,
3657 TosaTensorValuesGen.tvgDefault,
3658 None,
3659 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003660 "types": TYPE_FIB,
3661 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003662 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003663 "gather": {
3664 "op": Op.GATHER,
3665 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3666 "operands": (1, 0),
3667 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003668 "build_fcn": (
3669 build_gather,
3670 TosaTensorGen.tgBasic,
3671 TosaTensorValuesGen.tvgDefault,
3672 None,
3673 ),
James Ward24dbc422022-10-19 12:20:31 +01003674 "types": (
3675 DType.INT8,
3676 DType.INT16,
3677 DType.INT32,
3678 DType.FP16,
3679 DType.BF16,
3680 DType.FP32,
3681 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003682 "error_if_validators": (
3683 TosaErrorValidator.evWrongInputType,
3684 TosaErrorValidator.evWrongOutputType,
3685 TosaErrorValidator.evWrongInputList,
3686 TosaErrorValidator.evWrongOutputList,
3687 TosaErrorValidator.evWrongRank,
3688 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003689 },
3690 "scatter": {
3691 "op": Op.SCATTER,
3692 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003693 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003694 "operands": (2, 0),
3695 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003696 "build_fcn": (
3697 build_scatter,
3698 TosaTensorGen.tgScatter,
3699 TosaTensorValuesGen.tvgDefault,
3700 None,
3701 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003702 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003703 "error_if_validators": (
3704 TosaErrorValidator.evWrongInputType,
3705 TosaErrorValidator.evWrongOutputType,
3706 TosaErrorValidator.evWrongInputList,
3707 TosaErrorValidator.evWrongOutputList,
3708 TosaErrorValidator.evWrongRank,
3709 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003710 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003711 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003712 "resize": {
3713 "op": Op.RESIZE,
3714 "operands": (1, 0),
3715 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003716 "build_fcn": (
3717 build_resize,
3718 TosaTensorGen.tgNHWC,
3719 TosaTensorValuesGen.tvgDefault,
3720 TosaArgGen.agResize,
3721 ),
James Ward24dbc422022-10-19 12:20:31 +01003722 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003723 "invalid_test_validators": (
3724 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003725 ),
3726 "error_if_validators": (
3727 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003728 TosaErrorValidator.evScaleSmallerEqualZero,
3729 TosaErrorValidator.evScaleNLargerMax,
3730 TosaErrorValidator.evScaleDLargerMax,
3731 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003732 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003733 TosaErrorValidator.evBorderSmallerMin,
3734 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003735 TosaErrorValidator.evWrongInputType,
3736 TosaErrorValidator.evWrongOutputType,
3737 TosaErrorValidator.evWrongRank,
3738 TosaErrorValidator.evWrongInputList,
3739 TosaErrorValidator.evWrongOutputList,
3740 TosaErrorValidator.evBatchMismatch,
3741 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003742 TosaErrorValidator.evResizeOutputShapeMismatch,
3743 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003744 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003745 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003746 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003747 "cast": {
3748 "op": Op.CAST,
3749 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003750 "build_fcn": (
3751 build_cast,
3752 TosaTensorGen.tgBasic,
3753 TosaTensorValuesGen.tvgDefault,
3754 TosaArgGen.agCast,
3755 ),
James Ward8b390432022-08-12 20:48:56 +01003756 "types": (
3757 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01003758 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003759 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01003760 DType.INT8,
3761 DType.INT16,
3762 DType.INT32,
3763 DType.BOOL,
3764 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003765 "error_if_validators": (
3766 TosaErrorValidator.evWrongInputType,
3767 TosaErrorValidator.evWrongOutputType,
3768 TosaErrorValidator.evWrongInputList,
3769 TosaErrorValidator.evWrongOutputList,
3770 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003771 },
3772 "rescale": {
3773 "op": Op.RESCALE,
3774 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003775 "build_fcn": (
3776 build_rescale,
3777 TosaTensorGen.tgBasic,
3778 TosaTensorValuesGen.tvgDefault,
3779 TosaArgGen.agRescale,
3780 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003781 "types": [
3782 DType.UINT8,
3783 DType.INT8,
3784 DType.INT16,
3785 DType.INT32,
3786 DType.INT48,
3787 DType.UINT16,
3788 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003789 "error_if_validators": (
3790 TosaErrorValidator.evInputZeroPointNotZero,
3791 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003792 TosaErrorValidator.evU16InputZeroPointNotValid,
3793 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003794 TosaErrorValidator.evScaleTrue,
3795 TosaErrorValidator.evScaleNotTrue,
3796 TosaErrorValidator.evWrongInputType,
3797 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003798 TosaErrorValidator.evWrongInputList,
3799 TosaErrorValidator.evWrongOutputList,
3800 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003801 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003802 # Custom
3803 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003804 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07003805 # Two variants of cond_if, one that generates one of two constant tensors (no
3806 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
3807 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003808 "cond_if_const": {
3809 "op": Op.COND_IF,
3810 "operands": (0, 2),
3811 "build_fcn": (
3812 build_cond_if_const,
3813 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003814 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003815 TosaArgGen.agCondIf,
3816 ),
3817 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003818 "error_if_validators": (
3819 TosaErrorValidator.evOutputListThenGraphMismatch,
3820 TosaErrorValidator.evOutputListElseGraphMismatch,
3821 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003822 },
3823 "cond_if_binary": {
3824 "op": Op.COND_IF,
3825 "operands": (2, 0),
3826 "build_fcn": (
3827 build_cond_if_binary,
3828 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003829 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003830 TosaArgGen.agCondIf,
3831 ),
Les Bell6040b4d2021-10-11 12:50:31 +01003832 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003833 "error_if_validators": (
3834 TosaErrorValidator.evInputListThenGraphMismatch,
3835 TosaErrorValidator.evInputListElseGraphMismatch,
3836 TosaErrorValidator.evOutputListThenGraphMismatch,
3837 TosaErrorValidator.evOutputListElseGraphMismatch,
3838 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003839 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003840 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003841 "while_loop": {
3842 "op": Op.WHILE_LOOP,
3843 "operands": (0, 1),
3844 "build_fcn": (
3845 build_while_loop,
3846 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003847 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003848 TosaArgGen.agWhileLoop,
3849 ),
3850 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003851 "error_if_validators": (
3852 TosaErrorValidator.evInputListOutputListMismatch,
3853 TosaErrorValidator.evInputListCondGraphMismatch,
3854 TosaErrorValidator.evInputListBodyGraphInputMismatch,
3855 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
3856 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
3857 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003858 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003859 }
3860
Kevin Cheng550ccc52021-03-03 11:21:43 -08003861
Eric Kunzee5e26762020-10-13 16:11:07 -07003862class OutputShaper:
3863 # Methods in this class compute the expected output shape and datatype
3864 # for common classes of operations
def __init__(self):
    """Create an OutputShaper; no instance state is set up."""
3867
3868 # These methods return arguments that can be used for
3869 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an element-wise binary op with broadcasting.

    Output dimension ``i`` is ``b.shape[i]`` when ``a.shape[i] == 1`` and no
    ERROR_IF case is requested (``error_name is None``); otherwise it is
    ``a.shape[i]``. The output dtype is ``a.dtype`` unless
    ``ErrorIf.WrongOutputType`` is requested, in which case a different dtype
    is drawn with ``rng``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, only used to pick a wrong output dtype.
        a, b: operand tensors exposing ``shape`` and ``dtype``.
        error_name: optional ErrorIf case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    # Ranks may differ only when a rank mismatch is the error under test.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    broadcast = error_name is None
    out_shape = [
        b.shape[idx] if (dim == 1 and broadcast) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        # Any listed dtype other than the input's counts as a wrong output type.
        candidates = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        ) - {a.dtype}
        out_dtype = rng.choice(list(candidates))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003899
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for a binary op whose operands must match exactly.

    No broadcasting is performed: both operands must have the same rank, the
    same size in every dimension and the same dtype. The output takes that
    shared shape and dtype.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        a, b: operand tensors exposing ``shape`` and ``dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.

    Raises:
        AssertionError: if the ranks, any dimension, or the dtypes differ.
    """
    # Messages make generator failures diagnosable without a debugger.
    assert len(a.shape) == len(b.shape), f"rank mismatch: {a.shape} vs {b.shape}"
    assert a.dtype == b.dtype, f"dtype mismatch: {a.dtype} vs {b.dtype}"
    for axis, (dim_a, dim_b) in enumerate(zip(a.shape, b.shape)):
        assert dim_a == dim_b, (
            f"dimension {axis} mismatch: {a.shape} vs {b.shape}"
        )

    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003911
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for an element-wise unary op.

    The output always has ``a.shape``. Its dtype is ``a.dtype`` unless
    ``error_name`` is ``ErrorIf.WrongOutputType``, in which case a different
    dtype is drawn with ``rng``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name == ErrorIf.WrongOutputType:
        others = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        ) - {a.dtype}
        return ser.addOutput(a.shape, rng.choice(list(others)))

    return ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003930
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for SELECT(cond, a, b).

    The output shape follows ``cond``. Where ``cond`` has size 1 and no
    ERROR_IF case is requested, the largest of the three operand sizes at
    that axis is used instead. The dtype is ``a.dtype`` unless
    ``ErrorIf.WrongOutputType`` is requested, in which case a different
    dtype is drawn with ``rng``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    # All three ranks must agree unless a rank mismatch is being generated.
    if error_name != ErrorIf.RankMismatch:
        assert len(cond.shape) == len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    broadcast = error_name is None
    out_shape = []
    for idx, cond_dim in enumerate(cond.shape):
        if cond_dim == 1 and broadcast:
            out_shape.append(max(cond_dim, a.shape[idx], b.shape[idx]))
        else:
            out_shape.append(cond_dim)

    if error_name == ErrorIf.WrongOutputType:
        others = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        ) - {a.dtype}
        out_dtype = rng.choice(list(others))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003960
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the BOOL output tensor for an element-wise comparison op.

    Output dimension ``i`` is ``b.shape[i]`` when ``a.shape[i] == 1`` and
    ``b`` has an ``i``-th dimension; otherwise it is ``a.shape[i]``. The dtype
    is ``DType.BOOL`` unless ``ErrorIf.WrongOutputType`` is requested, in
    which case one of a fixed list of non-BOOL dtypes is drawn with ``rng``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast; the rank guard keeps rank-mismatch cases from indexing past b.
    b_rank = len(b.shape)
    out_shape = [
        b.shape[idx] if dim == 1 and idx < b_rank else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name == ErrorIf.WrongOutputType:
        # Ordered list (not a set): rng.choice draws from it directly.
        out_dtype = rng.choice(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )
    else:
        out_dtype = DType.BOOL

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003990
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a REDUCE_* op along ``axis``.

    Normally the output is ``a.shape`` with ``axis`` set to 1. For the
    axis-related ERROR_IF cases (AxisSmallerZero, AxisLargerRank,
    ShapeOfAxisNotOne) the axis size is left alone. For ShapeOfAxisNotOne,
    if it happens to be 1 it is replaced by a random size in [2, 10). The
    dtype is ``a.dtype`` unless ``ErrorIf.WrongOutputType`` is requested.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_shape = a.shape.copy()

    axis_errors = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name in axis_errors:
        # Ensure the reduced axis really is not 1 for the ShapeOfAxisNotOne case.
        if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    else:
        out_shape[axis] = 1

    if error_name == ErrorIf.WrongOutputType:
        others = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        ) - {a.dtype}
        out_dtype = rng.choice(list(others))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004019
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the INT32 output tensor for ARGMAX along ``axis``.

    The output is ``a.shape`` with ``axis`` removed, except for the
    AxisSmallerZero and AxisLargerRank ERROR_IF cases, where it is kept. Extra
    ERROR_IF perturbations:

    * ArgmaxOutputRankMismatch: drop the leading dim (only if more than one
      remains) or append a trailing 1, chosen with ``rng``.
    * ArgmaxOutputShapeMismatch: grow every remaining dim by ``rng.integers(1, 10)``.

    The dtype is INT32 unless ``ErrorIf.WrongOutputType`` is requested.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name == ErrorIf.WrongOutputType:
        not_int32 = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        ) - {DType.INT32}
        out_dtype = rng.choice(list(not_int32))
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004053
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for CONV2D.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC. ``padding[0:2]`` pad H
    and ``padding[2:4]`` pad W. The output channel count is ``filter.shape[0]``.

    ERROR_IF handling:

    * ConvOutputShapeMismatch: H and/or W are grown by whole strides.
    * WrongInputType: the output dtype falls back to INT32.
    * WrongOutputType: the dtype is drawn from ``usableDTypes``, excluding
      FP16 and FP32 for an FP16 input and the expected dtype otherwise.

    Otherwise the output dtype is ``accum_dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """

    def _out_dim(in_size, pad_lo, pad_hi, kernel, dilation, stride):
        # Dilated-kernel convolution output size along one spatial axis.
        return (in_size - 1 + pad_lo + pad_hi - (kernel - 1) * dilation) // stride + 1

    out_h = _out_dim(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_w = _out_dim(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        which = rng.choice(choices)
        # Grow by whole strides so the non-integer-output error is not hit too.
        if which in (1, 3):
            out_h += rng.choice(choices) * strides[0]
        if which in (2, 3):
            out_w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    # With a deliberately wrong input type, use INT32 as a plausible output type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            # Both are excluded for FP16 input (presumably both are valid outputs).
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004105
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for CONV3D.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC. ``padding[0:2]``,
    ``padding[2:4]`` and ``padding[4:6]`` pad D, H and W respectively. The
    output channel count is ``filter.shape[0]``.

    ERROR_IF handling:

    * ConvOutputShapeMismatch: D, H and/or W are grown by whole strides.
    * WrongInputType: the output dtype falls back to INT32.
    * WrongOutputType: the dtype is drawn from ``usableDTypes``, excluding
      FP16 and FP32 for an FP16 input and the expected dtype otherwise.

    Otherwise the output dtype is ``accum_dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """

    def _out_dim(in_size, pad_lo, pad_hi, kernel, dilation, stride):
        # Dilated-kernel convolution output size along one spatial axis.
        return (in_size - 1 + pad_lo + pad_hi - (kernel - 1) * dilation) // stride + 1

    out_d = _out_dim(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_h = _out_dim(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )
    out_w = _out_dim(
        ifm.shape[3], padding[4], padding[5], filter.shape[3], dilations[2], strides[2]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        which = rng.choice(choices)
        # Grow by whole strides so the non-integer-output error is not hit too.
        if which in (1, 4):
            out_d += rng.choice(choices) * strides[0]
        if which in (2, 4):
            out_h += rng.choice(choices) * strides[1]
        if which in (3, 4):
            out_w += rng.choice(choices) * strides[2]

    ofm_shape = [ifm.shape[0], out_d, out_h, out_w, filter.shape[0]]

    # With a deliberately wrong input type, use INT32 as a plausible output type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            # Both are excluded for FP16 input (presumably both are valid outputs).
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4167
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for DEPTHWISE_CONV2D.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M). ``padding[0:2]`` pad
    H and ``padding[2:4]`` pad W. The output channel count is
    ``filter.shape[2] * filter.shape[3]``.

    ERROR_IF handling:

    * ConvOutputShapeMismatch: H and/or W are grown by whole strides.
    * WrongInputType: the output dtype falls back to INT32.
    * WrongOutputType: the dtype is drawn from ``usableDTypes``, excluding
      FP16 and FP32 for an FP16 input and the expected dtype otherwise.

    Otherwise the output dtype is ``accum_dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """

    def _out_dim(in_size, pad_lo, pad_hi, kernel, dilation, stride):
        # Dilated-kernel convolution output size along one spatial axis.
        return (in_size - 1 + pad_lo + pad_hi - (kernel - 1) * dilation) // stride + 1

    out_h = _out_dim(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    out_w = _out_dim(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        which = rng.choice(choices)
        # Grow by whole strides so the non-integer-output error is not hit too.
        if which in (1, 3):
            out_h += rng.choice(choices) * strides[0]
        if which in (2, 3):
            out_w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[2] * filter.shape[3]]

    # With a deliberately wrong input type, use INT32 as a plausible output type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            # Both are excluded for FP16 input (presumably both are valid outputs).
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004218
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor for a 2D pooling op (input is NHWC).

    When both strides are positive and no padding value is negative, each
    spatial size is ``(in + pad_before + pad_after - kernel) // stride + 1``.
    Otherwise H and W are set to 1, since such a test is invalid anyway. For
    ``ErrorIf.PoolingOutputShapeMismatch``, H and/or W are then grown by whole
    strides. The channel count is copied from the input. The dtype matches
    ``ifm`` unless ``ErrorIf.WrongOutputType`` is requested.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    params_valid = stride[0] > 0 and stride[1] > 0 and min(pad) >= 0
    if params_valid:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        out_h = 1
        out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        which = rng.choice(choices)
        # Grow by whole strides so the non-integer-output error is not hit too.
        if which in (1, 3):
            out_h += rng.choice(choices) * stride[0]
        if which in (2, 3):
            out_w += rng.choice(choices) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name == ErrorIf.WrongOutputType:
        others = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        ) - {ifm.dtype}
        out_dtype = rng.choice(list(others))
    else:
        out_dtype = ifm.dtype

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004256
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add and return the output tensor of a fully connected operator.

    Shapes: ``input`` is [N, IC], ``filter`` is [OC, IC] and the output is
    [N, OC]. The output dtype is ``accum_dtype`` unchanged. ``rng`` and
    ``error_name`` are not used here.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    # accum_dtype is validated in arg_gen (and invalidated there for
    # ERROR_IF tests), so no further selection happens here.
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004269
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add and return the output tensor of MATMUL.

    Shapes: ``a`` is [N, H, C], ``b`` is [N, C, W] and the output is
    [N, H, W].

    The output dtype is ``accum_dtype`` (validated in arg_gen), except:
      * ErrorIf.WrongInputType: INT32, a potentially correct output dtype.
      * ErrorIf.WrongOutputType: a random dtype from a per-input list. The
        INT8 list omits INT32, the INT16 list omits INT48, and the float
        list (FP32/FP16/BF16 inputs) holds only integer types. Any other
        input dtype leaves the list unbound and raises UnboundLocalError.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        return ser.addOutput(output_shape, DType.INT32)

    if error_name != ErrorIf.WrongOutputType:
        # Validated in arg_gen
        return ser.addOutput(output_shape, accum_dtype)

    float_dtypes = (DType.FP32, DType.FP16, DType.BF16)
    if a.dtype == DType.INT8:
        wrong = (DType.INT4, DType.INT8, DType.INT16, DType.INT48) + float_dtypes
    elif a.dtype == DType.INT16:
        wrong = (DType.INT4, DType.INT8, DType.INT16, DType.INT32) + float_dtypes
    elif a.dtype in float_dtypes:
        wrong = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
    return ser.addOutput(output_shape, rng.choice(a=wrong))
Eric Kunzee5e26762020-10-13 16:11:07 -07004317
@staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Add and return the output tensor of CONCAT.

    The output shape is the first input's shape with the ``axis`` extent
    summed over all inputs. When that sum cannot be computed (rank
    mismatch or invalid-axis ERROR_IF cases), the first input's shape is
    used unchanged.

    ERROR_IF handling:
      * ErrorIf.ConcatShapeSumMismatch: a random 5..9 is added on ``axis``.
      * ErrorIf.WrongOutputType: a random dtype other than the first
        input's dtype.
    """
    first = a[0]
    others = a[1:]

    output_shape = first.shape.copy()
    can_sum = error_name not in (
        # tensors of different ranks cannot be concatenated
        ErrorIf.ConcatInputRankMismatch,
        # nor can tensors along an invalid axis
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if can_sum:
        for other in others:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    out_dtype = first.dtype
    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        } - {first.dtype}
        out_dtype = rng.choice(list(wrong_dtypes))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004353
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add and return the output tensor of PAD.

    Output dimension i is ``padding[i][0] + padding[i][1] + a.shape[i]``,
    i.e. the input extent plus the padding before and after that axis.

    ERROR_IF handling:
      * ErrorIf.PadOutputShapeMismatch: one randomly chosen dimension is
        enlarged by 1 or 2, so it no longer matches the padded input but
        stays a positive extent.
      * ErrorIf.WrongOutputType: a random dtype other than ``a.dtype``.

    Returns:
        The result of ``ser.addOutput`` for the output tensor.
    """
    output_shape = a.shape.copy()
    for i, dim in enumerate(a.shape):
        output_shape[i] = padding[i][0] + padding[i][1] + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(output_shape)))
        # Grow rather than shrink. Subtracting 1 or 2 from a dimension of
        # size 1 or 2 would give a zero or negative extent: an invalid shape
        # rather than a merely mismatched one.
        output_shape[bad_dim] += rng.choice([1, 2])

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004382
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add and return the output tensor of RESHAPE.

    The output shape is a copy of ``shape`` and the dtype follows ``a``.

    ERROR_IF handling:
      * ErrorIf.TensorSizeInputOutputMismatch: every dimension is
        increased by a random 1..9.
      * ErrorIf.WrongOutputType: a random dtype other than ``a.dtype``.
    """
    output_shape = shape.copy()
    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, extent in enumerate(output_shape):
            output_shape[idx] = extent + rng.integers(1, 10)

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = set(
            (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        )
        out_dtype = rng.choice(list(dtype_pool - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004407
@staticmethod
def sliceOp(ser, rng, a, start, size, error_name=None):
    """Add and return the output tensor of SLICE.

    The output shape is a copy of ``size`` and the dtype follows ``a``.
    ``start`` is not used here.

    ERROR_IF handling:
      * ErrorIf.WrongOutputType: a random dtype other than ``a.dtype``.
      * ErrorIf.SizeOutputShapeMismatch: every extent is moved by a
        non-zero random amount.
    """
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))
    else:
        out_dtype = a.dtype

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, extent in enumerate(output_shape):
            # Extents of 2 or less only grow, which keeps them positive;
            # larger ones may move in either direction.
            deltas = [1, 2] if extent <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = extent + rng.choice(deltas)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004439
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add and return the output tensor of TILE.

    Output dimension i is ``a.shape[i] * multiples[i]``. The rank of
    ``multiples`` must equal the input rank (asserted). For
    ErrorIf.WrongOutputType a random dtype other than ``a.dtype`` is used.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for axis, extent in enumerate(a.shape):
        output_shape[axis] = extent * multiples[axis]

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004465
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add and return the output tensor of TRANSPOSE.

    Output dimension i is ``a.shape[perms[i]]``. The length of ``perms``
    must equal the input rank (asserted).

    ERROR_IF handling:
      * ErrorIf.IndexOutsideBounds: ``perms`` is not indexed; every output
        dimension is ``a.shape[0]``.
      * ErrorIf.WrongOutputType: a random dtype other than ``a.dtype``.
    """
    output_shape = a.shape.copy()

    assert len(perms) == len(output_shape)

    out_of_bounds = error_name == ErrorIf.IndexOutsideBounds
    for axis in range(len(output_shape)):
        source_axis = 0 if out_of_bounds else perms[axis]
        output_shape[axis] = a.shape[source_axis]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    dtype_pool = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    return ser.addOutput(output_shape, rng.choice(list(set(dtype_pool) - {a.dtype})))
Eric Kunzee5e26762020-10-13 16:11:07 -07004495
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add and return the output tensor of GATHER.

    Unless testing ErrorIf.WrongRank, ``values`` must be rank 3 and
    ``indices`` rank 2, with matching dimension 0 (asserted). The output
    shape is [values.shape[0], indices.shape[1], values.shape[2]]. The
    dtype follows ``values``, or is a different random dtype for
    ErrorIf.WrongOutputType.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, values.dtype)

    dtype_pool = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(dtype_pool) - {values.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Kevin Cheng77d0f762020-11-24 10:26:32 -08004521
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add and return the output tensor of SCATTER.

    Unless testing ErrorIf.WrongRank, ``values_in`` and ``input`` must be
    rank 3 and ``indices`` rank 2. Their dims must agree as asserted below
    (values_in/indices on dim 0, input/indices on dim 1, values_in/input
    on dim 2).

    The output has the shape of ``values_in``. It uses a copy, so the
    output tensor never shares a shape object with ``values_in``, as in
    the other shapers here. The dtype follows ``values_in``, or is a
    different random dtype for ErrorIf.WrongOutputType.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    # Copy rather than alias: any later change to the output shape must not
    # leak back into the values_in tensor.
    output_shape = values_in.shape.copy()

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = values_in.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004550
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add and return the output tensor of TABLE.

    The output has the same shape as ``input``. Its dtype is INT32 for an
    INT16 input and INT8 otherwise. For ErrorIf.WrongOutputType, a
    different dtype is chosen at random. Unless testing
    ErrorIf.WrongInputType, the input dtype is asserted to be INT8 or
    INT16.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        out_dtype = DType.INT32
    else:
        out_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            dtype
            for dtype in (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
            if dtype != out_dtype
        ]
        out_dtype = rng.choice(candidates)

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004570
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add and return the output tensor of RESIZE.

    ``scale`` is read as [y_num, y_den, x_num, x_den]. ``offset`` and
    ``border`` are read as [y, x]. The output height and width are:
        OH = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
        OW = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1
    ``mode`` and ``input_dtype`` are not used here. The output dtype is
    ``output_dtype``.

    ERROR_IF handling:
      * ScaleSmallerEqualZero: all four scale terms are clamped to >= 1
        first.
      * any error case: OH/OW are clamped to >= 1 and, except for
        MaxDimExceeded, to <= MAX_RESIZE_DIMENSION - 1.
      * ResizeOutputShapeMismatch: OH and/or OW move by one denominator.
      * WrongRank / BatchMismatch / ChannelMismatch: the N or C entries
        are altered.
    """
    y_num, y_den = scale[0], scale[1]
    x_num, x_den = scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Keep every scale term positive for the shape arithmetic below.
        y_num, y_den, x_num, x_den = (
            max(term, 1) for term in (y_num, y_den, x_num, x_den)
        )

    oh = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    ow = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # Scale, offset or border may have been changed for ERROR_IFs;
        # keep the output tensor itself valid.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            ceiling = MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, ceiling), min(ow, ceiling)

    def _shift(extent, step):
        # Move by a whole denominator so the non-integer error case is not
        # hit; shrink instead when growing would reach the maximum size.
        if extent + step < MAX_RESIZE_DIMENSION:
            return extent + step
        extent -= step
        assert extent > 0  # Should have been caught in agResize
        return extent

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        which = rng.choice([1, 2, 3])
        if which != 2:  # 1 or 3: change the height
            oh = _shift(oh, y_den)
        if which != 1:  # 2 or 3: change the width
            ow = _shift(ow, x_den)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # The last entry reuses the batch size instead of input.shape[3],
        # presumably because a wrong-rank input may lack that dimension.
        output_dims = [batch, oh, ow, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004649
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add and return an output tensor shaped like ``val`` with ``out_dtype``.

    ``rng`` and ``error_name`` are not used; no ERROR_IF adjustment is made.
    """
    shape = val.shape
    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004653
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add and return the output tensor of TRANSPOSE_CONV2D.

    The output shape is ``output_shape``. The dtype is ``accum_dtype``,
    or INT32 for ErrorIf.WrongInputType.

    ERROR_IF handling:
      * ErrorIf.ConvOutputShapeMismatch: entries 1 and/or 2 of
        ``output_shape`` (presumably H and W) are increased by 1..3.
        NOTE(review): this modifies ``output_shape`` in place, so the
        caller's list sees the change as well.
      * ErrorIf.WrongOutputType: a random usable dtype outside an excluded
        set. The set is FP16 and FP32 for an FP16 input; otherwise it is
        just the dtype that would have been used.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        for dim, selected in ((1, which != 2), (2, which != 1)):
            if selected:
                output_shape[dim] = output_shape[dim] + rng.choice(steps)

    # Pick some potentially correct output dtype if input type is incorrect
    if error_name == ErrorIf.WrongInputType:
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excluded = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(usableDTypes(excludes=excluded)))

    return ser.addOutput(output_shape, out_dtype)