blob: 3014c816951442f8ba40f2361572fcd7a6207c71 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +010016from generator.tosa_utils import DTYPE_ATTRIBUTES
Luke Huttona4e48ca2023-02-22 11:53:48 +000017from generator.tosa_utils import get_rank_mismatch_shape
Jeremy Johnson05c711e2022-12-12 18:00:41 +000018from generator.tosa_utils import get_wrong_output_type
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010019from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010020from generator.tosa_utils import usableDTypes
James Ward24dbc422022-10-19 12:20:31 +010021from generator.tosa_utils import vect_f32_to_bf16
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
25
Eric Kunzee5e26762020-10-13 16:11:07 -070026class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010027 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000028 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010029 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010030 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010031 TOSA_8K_LEVEL_MAX_KERNEL = 8192
32 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010033
Eric Kunzee5e26762020-10-13 16:11:07 -070034 def __init__(self, args):
35 self.args = args
36 self.basePath = args.output_dir
37 self.random_seed = args.random_seed
38 self.ser = None
39 self.rng = np.random.default_rng(self.random_seed)
40 self.createDynamicOpLists()
41 self.initOpListDefaults()
42 self.quantGen = TosaQuantGen()
43 # Force makeShape to do a specific starting shape
44 self.targetted_shape = None
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010045 # Work out floating point range
46 self.random_fp_low = min(args.tensor_fp_value_range)
47 self.random_fp_high = max(args.tensor_fp_value_range)
Eric Kunzee5e26762020-10-13 16:11:07 -070048
49 def createSerializer(self, opName, testPath):
50 self.testPath = os.path.join(opName, testPath)
51
52 fullPath = os.path.join(self.basePath, self.testPath)
53 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +010054 # Embed const data in the flatbuffer
55 constMode = ts.ConstMode.EMBED
56 if self.args.dump_consts:
57 constMode = ts.ConstMode.EMBED_DUMP
58 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070059
60 def getSerializer(self):
61 return self.ser
62
63 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -080064 with open(
65 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
66 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -070067 fd.write(self.ser.serialize())
68
Kevin Cheng550ccc52021-03-03 11:21:43 -080069 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
70 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -070071
Matthew Haddon74567092021-07-16 15:38:20 +010072 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000073 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +010074 seed = self.random_seed + 1
75 self.rng = np.random.default_rng(seed)
76
Eric Kunzee5e26762020-10-13 16:11:07 -070077 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -070078 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -070079 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -070080 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -070081 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -070082 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070083 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +010084 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
85 elif dtype == DType.UINT8:
86 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070087 elif dtype == DType.INT16:
88 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010089 elif dtype == DType.UINT16:
90 return np.int32(self.rng.integers(low=0, high=65536, size=shape))
Won Jeona21b2e82023-08-10 10:33:01 +000091 elif (
92 dtype == DType.INT32 or dtype == DType.SHAPE
93 ): # restricting too large value for SHAPE
Kevin Cheng550ccc52021-03-03 11:21:43 -080094 return np.int32(
95 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
96 )
Eric Kunzee5e26762020-10-13 16:11:07 -070097 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -080098 return np.int64(
99 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
100 )
James Ward8b390432022-08-12 20:48:56 +0100101 elif dtype == DType.FP16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100102 return np.float16(
103 self.rng.uniform(
104 low=self.random_fp_low, high=self.random_fp_high, size=shape
105 )
106 )
James Ward24dbc422022-10-19 12:20:31 +0100107 elif dtype == DType.BF16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100108 f32_tensor = np.float32(
109 self.rng.uniform(
110 low=self.random_fp_low, high=self.random_fp_high, size=shape
111 )
112 )
James Ward24dbc422022-10-19 12:20:31 +0100113 # Floor the last 16 bits of each f32 value
114 return np.float32(vect_f32_to_bf16(f32_tensor))
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100115 elif dtype == DType.FP32:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100116 return np.float32(
117 self.rng.uniform(
118 low=self.random_fp_low, high=self.random_fp_high, size=shape
119 )
120 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700121 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800122 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700123
Kevin Cheng989cb052021-04-28 16:29:44 -0700124 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700125 placeholders = []
126
Kevin Cheng989cb052021-04-28 16:29:44 -0700127 assert len(shape_list) == len(dtype_list)
128
129 for idx, shape in enumerate(shape_list):
130 arr = self.getRandTensor(shape, dtype_list[idx])
131 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700132
133 return placeholders
134
Kevin Cheng989cb052021-04-28 16:29:44 -0700135 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700136 consts = []
137
Kevin Cheng989cb052021-04-28 16:29:44 -0700138 assert len(shape_list) == len(dtype_list)
139
140 for idx, shape in enumerate(shape_list):
141 arr = self.getRandTensor(shape, dtype_list[idx])
142 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700143
144 return consts
145
146 def makeShape(self, rank):
147 if self.targetted_shape:
148 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800149 return np.int32(
150 self.rng.integers(
151 low=self.args.tensor_shape_range[0],
152 high=self.args.tensor_shape_range[1],
153 size=rank,
154 )
155 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700156
157 def setTargetShape(self, shape):
158 self.targetted_shape = shape
159
160 def randInt(self, low=0, high=256):
161 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
162
163 def getRandNumberDType(self, dtype):
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100164 if dtype == DType.FP32:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100165 return np.float32(
166 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
167 )
James Ward8b390432022-08-12 20:48:56 +0100168 elif dtype == DType.FP16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100169 return np.float16(
170 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
171 )
James Ward24dbc422022-10-19 12:20:31 +0100172 elif dtype == DType.BF16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100173 rand_f32 = np.float32(
174 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
175 )
James Ward24dbc422022-10-19 12:20:31 +0100176 return vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700177 elif dtype == DType.BOOL:
178 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -0700179 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -0700180 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -0700181 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700182 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100183 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -0700184 elif dtype == DType.INT16:
185 low, high = (-32768, 32768)
Won Jeona21b2e82023-08-10 10:33:01 +0000186 elif (
187 dtype == DType.INT32 or dtype == DType.SHAPE
188 ): # restricting too large value for SHAPE
Kevin Cheng550ccc52021-03-03 11:21:43 -0800189 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -0700190 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800191 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -0700192 # Special size
193 return np.int64(self.rng.integers(low, high, size=1))[0]
194 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800195 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700196
197 return np.int32(self.rng.integers(low, high, size=1))[0]
198
199 def shapeStr(self, shape):
200
201 sStr = []
202 # Convert to strings
203 for i in shape:
204 sStr.append(str(i))
205
Kevin Cheng550ccc52021-03-03 11:21:43 -0800206 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700207
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100208 def typeStr(self, dtype):
209 if isinstance(dtype, list) or isinstance(dtype, tuple):
210 assert len(dtype) >= 2
211 strs = [self.typeStr(t) for t in dtype]
212 # Limit types to the first 2 as the 3rd is the accumulator
213 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700214 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100215 if dtype in DTYPE_ATTRIBUTES:
216 return DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700217 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100218 raise Exception(
219 "Unknown dtype, cannot convert to string: {}".format(dtype)
220 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700221
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100222 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100223 """Get the datatype width for data types"""
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100224 if dtype in DTYPE_ATTRIBUTES:
225 return DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700226 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100227 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700228
Luke Hutton57287132023-02-06 14:54:18 +0000229 def constrictBatchSize(self, shape):
230 # Limit the batch size unless an explicit target shape set
231 if self.args.max_batch_size and not self.args.target_shapes:
232 shape[0] = min(shape[0], self.args.max_batch_size)
233 return shape
234
James Ward30124a82023-02-02 14:56:33 +0000235 def makeDimension(self):
236 return self.randInt(
237 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
238 )
239
    # Argument generators
    # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
    # where the string descriptor is used to generate the test name and
    # the build_fcn_arg_list is expanded and passed to the operator test
    # build function
245
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100246 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
247 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
248
Matthew Haddon848efb42021-09-09 12:30:53 +0100249 # build_placeholder returns an int, ABS/other ops does not
250 if isinstance(op, int):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000251 self.ser.addOperator(op, a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100252 return result_tens
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000253 elif op["op"] == Op.IDENTITY:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000254 self.ser.addOperator(op["op"], a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100255 return result_tens
256
257 # Ensure new output type has correct qinfo
258 if error_name == ErrorIf.WrongOutputType:
259 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000260 qinfo = [
261 TosaQuantGen.getZeroPoint(self, a.dtype),
262 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
263 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100264
265 # Invalidate Input/Output list for error if checks.
266 input_list = [a.name]
267 output_list = [result_tens.name]
268 pCount, cCount = op["operands"]
269 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000270 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
271 self, error_name, input_list, output_list
272 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100273
Les Bell729b0352021-11-24 10:28:21 +0000274 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100275 self.ser,
276 validator_fcns,
277 error_name,
278 op=op,
279 input_dtype=a.dtype,
280 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000281 qinfo=qinfo,
Luke Hutton261b7b62023-01-10 14:50:31 +0000282 result_tensors=[result_tens],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100283 input_list=input_list,
284 output_list=output_list,
285 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000286 ):
287 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100288
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000289 attr = None
290 if op["op"] == Op.NEGATE:
291 attr = ts.TosaSerializerAttribute()
292 attr.NegateAttribute(qinfo[0], qinfo[1])
293
294 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700295 return result_tens
296
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100297 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000298 result_tens = OutputShaper.binaryBroadcastOp(
299 self.ser, self.rng, a, b, error_name
300 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100301
302 # Invalidate Input/Output list for error if checks.
303 input_list = [a.name, b.name]
304 output_list = [result_tens.name]
305 pCount, cCount = op["operands"]
306 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000307 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
308 self, error_name, input_list, output_list
309 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100310
Les Bell729b0352021-11-24 10:28:21 +0000311 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100312 self.ser,
313 validator_fcns,
314 error_name,
315 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000316 input1=a,
317 input2=b,
318 input_dtype=a.dtype,
319 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000320 result_tensors=[result_tens],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100321 input_list=input_list,
322 output_list=output_list,
323 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000324 ):
325 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100326
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000327 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -0700328 return result_tens
329
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100330 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700331 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000332 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700333 return result_tens
334
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000335 def build_arithmetic_right_shift(
336 self, op, a, b, round, validator_fcns=None, error_name=None
337 ):
338 result_tens = OutputShaper.binaryBroadcastOp(
339 self.ser, self.rng, a, b, error_name
340 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100341
342 # Invalidate Input/Output list for error if checks.
343 input_list = [a.name, b.name]
344 output_list = [result_tens.name]
345 pCount, cCount = op["operands"]
346 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000347 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
348 self, error_name, input_list, output_list
349 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100350
Les Bell729b0352021-11-24 10:28:21 +0000351 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100352 self.ser,
353 validator_fcns,
354 error_name,
355 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000356 input1=a,
357 input2=b,
358 input_dtype=a.dtype,
359 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000360 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100361 input_list=input_list,
362 output_list=output_list,
363 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000364 ):
365 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800366
367 attr = ts.TosaSerializerAttribute()
368 attr.ArithmeticRightShiftAttribute(round)
369
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000370 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800371 return result_tens
372
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100373 def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000374 result_tens = OutputShaper.binaryBroadcastOp(
375 self.ser, self.rng, a, b, error_name
376 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700377
378 # Special for multiply:
379 # Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100380 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Eric Kunzee5e26762020-10-13 16:11:07 -0700381 result_tens.setDtype(DType.INT32)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100382 if error_name == ErrorIf.WrongOutputType:
383 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
384 outputDType = self.rng.choice(all_dtypes)
385 result_tens.setDtype(outputDType)
386
387 # Invalidate Input/Output list for error if checks.
388 input_list = [a.name, b.name]
389 output_list = [result_tens.name]
390 pCount, cCount = op["operands"]
391 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000392 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
393 self, error_name, input_list, output_list
394 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100395
Les Bell729b0352021-11-24 10:28:21 +0000396 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100397 self.ser,
398 validator_fcns,
399 error_name,
400 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000401 input1=a,
402 input2=b,
403 input_dtype=a.dtype,
404 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000405 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100406 input_list=input_list,
407 output_list=output_list,
408 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000409 ):
410 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700411
Kevin Chengaee1fac2020-11-11 13:54:06 -0800412 attr = ts.TosaSerializerAttribute()
413 attr.MulAttribute(shift)
414
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000415 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700416 return result_tens
417
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100418 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
419 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700420
Kevin Chengfe392ce2021-10-18 21:51:55 +0000421 attr = ts.TosaSerializerAttribute()
422 attr.TableAttribute(table)
423
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100424 # Invalidate Input/Output list for error if checks.
425 input_list = [a.name]
426 output_list = [result_tens.name]
427 pCount, cCount = op["operands"]
428 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000429 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
430 self, error_name, input_list, output_list
431 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100432
Les Bell729b0352021-11-24 10:28:21 +0000433 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100434 self.ser,
435 validator_fcns,
436 error_name,
437 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000438 input_shape=a.shape,
439 input_dtype=a.dtype,
440 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000441 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100442 input_list=input_list,
443 output_list=output_list,
444 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000445 ):
446 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100447
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000448 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700449
450 return result_tens
451
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100452 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
453 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
454
455 # Invalidate Input/Output list for error if checks.
456 input_list = [cond.name, a.name, b.name]
457 output_list = [result_tens.name]
458 pCount, cCount = op["operands"]
459 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000460 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
461 self, error_name, input_list, output_list
462 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100463
Les Bell729b0352021-11-24 10:28:21 +0000464 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100465 self.ser,
466 validator_fcns,
467 error_name,
468 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000469 input1=cond,
470 input2=a,
471 input3=b,
472 input_shape=a.shape,
473 input_dtype=a.dtype,
474 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000475 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100476 input_list=input_list,
477 output_list=output_list,
478 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000479 ):
480 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100481
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000482 self.ser.addOperator(
483 op["op"],
484 input_list,
485 output_list,
486 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700487 return result_tens
488
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100489 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000490 result_tens = OutputShaper.binaryComparisonOp(
491 self.ser, self.rng, a, b, error_name
492 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100493
494 # Invalidate Input/Output list for error if checks.
495 input_list = [a.name, b.name]
496 output_list = [result_tens.name]
497 pCount, cCount = op["operands"]
498 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000499 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
500 self, error_name, input_list, output_list
501 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100502
Les Bell729b0352021-11-24 10:28:21 +0000503 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100504 self.ser,
505 validator_fcns,
506 error_name,
507 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000508 input1=a,
509 input2=b,
510 input_shape=a.shape,
511 input_dtype=a.dtype,
512 output_shape=result_tens.shape,
513 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000514 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100515 input_list=input_list,
516 output_list=output_list,
517 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000518 ):
519 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100520
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000521 self.ser.addOperator(
522 op["op"],
523 input_list,
524 output_list,
525 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700526 return result_tens
527
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100528 def build_argmax(self, op, a, axis, validator_fcns, error_name):
529 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
530
531 # Invalidate Input/Output list for error if checks.
532 input_list = [a.name]
533 output_list = [result_tens.name]
534 pCount, cCount = op["operands"]
535 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000536 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
537 self, error_name, input_list, output_list
538 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100539
Les Bell729b0352021-11-24 10:28:21 +0000540 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100541 self.ser,
542 validator_fcns,
543 error_name,
544 op=op,
545 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000546 input_shape=a.shape,
547 input_dtype=a.dtype,
548 output_shape=result_tens.shape,
549 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000550 result_tensors=[result_tens],
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100551 input_list=input_list,
552 output_list=output_list,
553 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000554 ):
555 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700556
557 attr = ts.TosaSerializerAttribute()
558 attr.AxisAttribute(axis)
559
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000560 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700561 return result_tens
562
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000563 def build_pool2d(
564 self,
565 op,
566 input,
James Ward8b390432022-08-12 20:48:56 +0100567 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000568 stride,
569 pad,
570 kernel,
571 validator_fcns=None,
572 error_name=None,
573 qinfo=None,
574 ):
575 result_tens = OutputShaper.pool2dOp(
576 self.ser, self.rng, input, kernel, stride, pad, error_name
577 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100578
579 # Ensure new output type has correct qinfo
580 if error_name == ErrorIf.WrongInputType:
581 if input.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000582 qinfo = [
583 TosaQuantGen.getZeroPoint(self, input.dtype),
584 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
585 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100586
587 # Invalidate Input/Output list for error if checks.
588 input_list = [input.name]
589 output_list = [result_tens.name]
590 pCount, cCount = op["operands"]
591 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000592 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
593 self, error_name, input_list, output_list
594 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100595
Les Bell729b0352021-11-24 10:28:21 +0000596 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100597 self.ser,
598 validator_fcns,
599 error_name,
600 op=op,
601 input_shape=input.shape,
602 input_dtype=input.dtype,
603 output_shape=result_tens.shape,
604 output_dtype=result_tens.dtype,
605 kernel=kernel,
606 stride=stride,
607 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000608 qinfo=qinfo,
Luke Hutton261b7b62023-01-10 14:50:31 +0000609 result_tensors=[result_tens],
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100610 input_list=input_list,
611 output_list=output_list,
612 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000613 ):
614 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700615
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000616 if qinfo is None:
617 qinfo = [0, 0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700618
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000619 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100620 attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000621
622 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700623 return result_tens
624
James Ward8b390432022-08-12 20:48:56 +0100625 def build_maxpool2d(
626 self,
627 op,
628 input,
629 stride,
630 pad,
631 kernel,
632 validator_fcns=None,
633 error_name=None,
634 qinfo=None,
635 ):
636 # Same as build_pool2d but manually sets accum_dtype value
637 # (maxpool has no accum_dtype)
638 return self.build_pool2d(
639 op,
640 input,
641 DType.UNKNOWN,
642 stride,
643 pad,
644 kernel,
645 validator_fcns,
646 error_name,
647 qinfo,
648 )
649
def build_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator and add it to the serializer.

    The result tensor comes from OutputShaper.conv2dOp.  The ERROR_IF
    validators are then run; when they pass, the operator is serialized
    with a ConvAttribute built from padding, strides, dilations and the
    first two entries of qinfo.

    padding must have exactly four entries (asserted).  qinfo is indexed
    unconditionally, so despite its None default a two-element sequence
    is required unless it is regenerated by the WrongInputType branch.

    Returns the result tensor, or None when
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    assert len(padding) == 4
    result = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo: for a WrongInputType test
    # whose input is not INT8/UINT8, rebuild qinfo from the actual dtypes.
    eight_bit_dtypes = (DType.INT8, DType.UINT8)
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in eight_bit_dtypes:
        zp_for_input_dtype = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        zp_for_result_dtype = TosaQuantGen.getZeroPoint(self, result.dtype)
        qinfo = [zp_for_input_dtype, zp_for_result_dtype]

    # Invalidate Input/Output list for error_if checks.
    input_names = [ifm.name, filter.name, bias.name]
    output_names = [result.name]
    operand_total = sum(op["operands"])
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result.dtype,
        qinfo=qinfo,
        input_list=input_names,
        num_operands=operand_total,
        output_list=output_names,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result.shape,
    )
    if not passed:
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_names, output_names, conv_attr)
    return result
721
def build_conv3d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator and add it to the serializer.

    The result tensor comes from OutputShaper.conv3dOp.  The ERROR_IF
    validators are then run; when they pass, the operator is serialized
    with a ConvAttribute built from padding, strides, dilations and the
    first two entries of qinfo.

    padding must have exactly six entries (asserted).  qinfo is indexed
    unconditionally, so despite its None default a two-element sequence
    is required unless it is regenerated by the WrongInputType branch.

    Returns the result tensor, or None when
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    assert len(padding) == 6
    result = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo: for a WrongInputType test
    # whose input is not INT8/UINT8, rebuild qinfo from the actual dtypes.
    eight_bit_dtypes = (DType.INT8, DType.UINT8)
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in eight_bit_dtypes:
        zp_for_input_dtype = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        zp_for_result_dtype = TosaQuantGen.getZeroPoint(self, result.dtype)
        qinfo = [zp_for_input_dtype, zp_for_result_dtype]

    # Invalidate Input/Output list for error_if checks.
    input_names = [ifm.name, filter.name, bias.name]
    output_names = [result.name]
    operand_total = sum(op["operands"])
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result.dtype,
        qinfo=qinfo,
        input_list=input_names,
        num_operands=operand_total,
        output_list=output_names,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result.shape,
    )
    if not passed:
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_names, output_names, conv_attr)
    return result
793
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator and add it to the serializer.

    The result tensor comes from OutputShaper.transposeConv2DOp.  The
    ERROR_IF validators are then run; when they pass, the operator is
    serialized with a TransposeConvAttribute built from out_pad, stride,
    output_shape and the first two entries of qinfo.

    out_pad must have exactly four entries (asserted).  qinfo is indexed
    unconditionally, so despite its None default a two-element sequence
    is required unless it is regenerated by the WrongInputType branch.

    Returns the result tensor, or None when
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    assert len(out_pad) == 4
    result = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo: for a WrongInputType test
    # whose input is not INT8/UINT8, rebuild qinfo from the actual dtypes.
    eight_bit_dtypes = (DType.INT8, DType.UINT8)
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in eight_bit_dtypes:
        zp_for_input_dtype = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        zp_for_result_dtype = TosaQuantGen.getZeroPoint(self, result.dtype)
        qinfo = [zp_for_input_dtype, zp_for_result_dtype]

    # Invalidate Input/Output list for error_if checks.
    input_names = [ifm.name, filter.name, bias.name]
    output_names = [result.name]
    operand_total = sum(op["operands"])
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result.dtype,
        qinfo=qinfo,
        input_list=input_names,
        num_operands=operand_total,
        output_list=output_names,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result.shape,
    )
    if not passed:
        return None

    tconv_attr = ts.TosaSerializerAttribute()
    tconv_attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1]
    )

    self.ser.addOperator(op["op"], input_names, output_names, tconv_attr)
    return result
856
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator and add it to the serializer.

    The result tensor comes from OutputShaper.depthwiseConv2dOp.  The
    ERROR_IF validators are then run; when they pass, the operator is
    serialized with a ConvAttribute built from padding, strides, dilations
    and the first two entries of qinfo.

    qinfo is indexed unconditionally, so despite its None default a
    two-element sequence is required unless it is regenerated by the
    WrongInputType branch.

    Returns the result tensor, or None when
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    result = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo: for a WrongInputType test
    # whose input is not INT8/UINT8, rebuild qinfo from the actual dtypes.
    eight_bit_dtypes = (DType.INT8, DType.UINT8)
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in eight_bit_dtypes:
        zp_for_input_dtype = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        zp_for_result_dtype = TosaQuantGen.getZeroPoint(self, result.dtype)
        qinfo = [zp_for_input_dtype, zp_for_result_dtype]

    # Invalidate Input/Output list for error_if checks.
    input_names = [ifm.name, filter.name, bias.name]
    output_names = [result.name]
    operand_total = sum(op["operands"])
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result.dtype,
        qinfo=qinfo,
        input_list=input_names,
        num_operands=operand_total,
        output_list=output_names,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result.shape,
    )
    if not passed:
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_names, output_names, conv_attr)
    return result
927
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator and add it to the serializer.

    The result tensor comes from OutputShaper.fullyConnectedOp.  The
    ERROR_IF validators are then run; when they pass, the operator is
    serialized with a FullyConnectedAttribute built from the first two
    entries of qinfo.

    qinfo is indexed unconditionally, so despite its None default a
    two-element sequence is required.

    Returns the result tensor, or None when
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    result = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_names = [ifm.name, filter.name, bias.name]
    output_names = [result.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=result.shape,
        output_dtype=result.dtype,
        qinfo=qinfo,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
        accum_dtype=accum_dtype,
    )
    if not passed:
        return None

    fc_attr = ts.TosaSerializerAttribute()
    fc_attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_names, output_names, fc_attr)
    return result
976
def build_matmul(
    self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator and add it to the serializer.

    The result tensor comes from OutputShaper.matmulOp.  The ERROR_IF
    validators are then run; when they pass, the operator is serialized
    with a MatMulAttribute built from the first two entries of qinfo.

    qinfo is indexed unconditionally, so despite its None default a
    two-element sequence is required.

    Returns the result tensor, or None when
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    result = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name, b.name]
    output_names = [result.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result.shape,
        output_dtype=result.dtype,
        qinfo=qinfo,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
        accum_dtype=accum_dtype,
    )
    if not passed:
        return None

    matmul_attr = ts.TosaSerializerAttribute()
    matmul_attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_names, output_names, matmul_attr)
    return result
1018
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Build a reduction operator over `axis` and add it to the serializer.

    The result tensor comes from OutputShaper.reduceOp.  The ERROR_IF
    validators are then run; when they pass, the operator is serialized
    with an AxisAttribute carrying `axis`.

    validator_fcns now defaults to None, matching the other build_*
    methods of this class; existing positional and keyword callers are
    unaffected.

    Returns the result tensor, or None when
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1053
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Build a CLAMP operator with random bounds and add it to the serializer.

    Two random values of the input dtype are drawn via getRandNumberDType
    and ordered into (min_val, max_val).  For the MaxSmallerMin error
    case the draw is repeated until the values differ and the ordering is
    deliberately inverted.  After the ERROR_IF validators pass, the bounds
    are written into a ClampAttribute: in the last two argument slots for
    BF16/FP16/FP32 inputs, in the first two otherwise.

    Returns the result tensor, or None when
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    result = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    bounds = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    if error_name == ErrorIf.MaxSmallerMin:
        # Make sure the numbers are different to invoke this error
        while bounds[0] == bounds[1]:
            bounds = [
                self.getRandNumberDType(a.dtype),
                self.getRandNumberDType(a.dtype),
            ]
        # Inverted on purpose: max receives the smaller value.
        max_val, min_val = min(bounds), max(bounds)
    else:
        max_val, min_val = max(bounds), min(bounds)

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name]
    output_names = [result.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=result.shape,
        input_dtype=a.dtype,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    float_dtypes = (DType.BF16, DType.FP16, DType.FP32)
    if a.dtype not in float_dtypes:
        clamp_attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)

        clamp_attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], input_names, output_names, clamp_attr)
    return result
1109
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a leaky-ReLU operator and add it to the serializer.

    The result tensor comes from OutputShaper.unaryOp and the attribute is
    a LeakyReluAttribute holding one random FP32 value from
    getRandNumberDType.  validator_fcns is accepted but not used: no
    ERROR_IF validation is run in this builder.

    Returns the result tensor.
    """
    result = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    relu_attr = ts.TosaSerializerAttribute()

    random_fp32 = self.getRandNumberDType(DType.FP32)
    relu_attr.LeakyReluAttribute(random_fp32)

    input_names = [a.name]
    output_names = [result.name]
    self.ser.addOperator(op["op"], input_names, output_names, relu_attr)
    return result
1118
1119 # Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PReLU operator (no attribute) and add it to the serializer.

    The result tensor comes from OutputShaper.unaryOp.  validator_fcns is
    accepted but not used: no ERROR_IF validation is run in this builder.

    Returns the result tensor.
    """
    result = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    input_names = [a.name]
    output_names = [result.name]
    self.ser.addOperator(op["op"], input_names, output_names)
    return result
1125
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Build a SIGMOID operator and add it to the serializer.

    The result tensor comes from OutputShaper.unaryOp.  The ERROR_IF
    validators are then run; when they pass, the operator is serialized
    without an attribute.

    Returns the result tensor, or None when
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    result = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name]
    output_names = [result.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result.shape,
        input_dtype=a.dtype,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)
    return result
1156
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Build a TANH operator and add it to the serializer.

    The result tensor comes from OutputShaper.unaryOp.  The ERROR_IF
    validators are then run; when they pass, the operator is serialized
    without an attribute.

    Returns the result tensor, or None when
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    result = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name]
    output_names = [result.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result.shape,
        input_dtype=a.dtype,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)
    return result
1187
def build_erf(self, op, a, validator_fcns=None, error_name=None):
    """Build an ERF operator and add it to the serializer.

    The result tensor comes from OutputShaper.unaryOp.  The ERROR_IF
    validators are then run; when they pass, the operator is serialized
    without an attribute.

    Returns the result tensor, or None when
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    result = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name]
    output_names = [result.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result.shape,
        input_dtype=a.dtype,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)
    return result
1218
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Build a CONCAT operator and add it to the serializer.

    The positional arguments after `op` are the input tensors followed by
    the axis as the final element; the axis is split off before use.
    Unless the WrongInputType error is being generated, the final element
    is asserted to be exactly an int.

    The result tensor comes from OutputShaper.concatOp.  The ERROR_IF
    validators are then run (using the first tensor's shape and dtype as
    input_shape/input_dtype); when they pass, the operator is serialized
    with an AxisAttribute carrying the axis.

    Returns the result tensor, or None when
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    if error_name != ErrorIf.WrongInputType:
        assert type(a[-1]) is int

    # To store variable length list of input tensors we need to store axis
    # along with it: the last positional is the axis, the rest are tensors.
    axis = a[-1]
    a = a[:-1]

    result = OutputShaper.concatOp(
        self.ser, self.rng, axis, *a, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_names = [tensor.name for tensor in a]
    output_names = [result.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a[0].shape,
        output_shape=result.shape,
        input_dtype=a[0].dtype,
        output_dtype=result.dtype,
        inputs=a,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_names, output_names, axis_attr)
    return result
Eric Kunzee5e26762020-10-13 16:11:07 -07001267
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator and add it to the serializer.

    The result tensor comes from OutputShaper.padOp.  A PadAttribute is
    created from the flattened padding and the two pad constants before
    validation runs.  The ERROR_IF validators are then run; when they
    pass, the operator is serialized with that attribute.

    Returns the result tensor, or None when
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    result = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    # The attribute is built ahead of validation, as in the other steps
    # below it only gets attached if validation succeeds.
    pad_attr = ts.TosaSerializerAttribute()
    builder = self.ser.builder
    flat_padding = padding.flatten()
    pad_attr.PadAttribute(builder, flat_padding, pad_const_int, pad_const_float)

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name]
    output_names = [result.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result.shape,
        input_dtype=a.dtype,
        output_dtype=result.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
        input1=a,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_names, output_names, pad_attr)
    return result
Eric Kunzee5e26762020-10-13 16:11:07 -07001316
def build_dim(
    self,
    op,
    a,
    axis,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator for `axis` and add it to the serializer.

    The result tensor comes from OutputShaper.dimOp.  The ERROR_IF
    validators are then run; when they pass, the operator is serialized
    with an AxisAttribute carrying `axis`.  qinfo is accepted but not
    referenced by this builder.

    Returns the result tensor, or None when
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    result = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name]
    output_names = [result.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result.shape,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_names, output_names, axis_attr)
    return result
1359
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Build a RESHAPE operator to `newShape` and add it to the serializer.

    The result tensor comes from OutputShaper.reshapeOp.  The ERROR_IF
    validators are then run; when they pass, the operator is serialized
    with a ReshapeAttribute carrying `newShape`.

    Returns the result tensor, or None when
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    result = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name]
    output_names = [result.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result.shape,
        input_dtype=a.dtype,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], input_names, output_names, reshape_attr)
    return result
1395
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Build a REVERSE operator along `axis` and add it to the serializer.

    The result tensor comes from OutputShaper.unaryOp.  The ERROR_IF
    validators are then run; when they pass, the operator is serialized
    with an AxisAttribute carrying `axis`.

    Returns the result tensor, or None when
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    result = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name]
    output_names = [result.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result.shape,
        input_dtype=a.dtype,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_names, output_names, axis_attr)
    return result
1430
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Serialize the operator ``op["op"]`` on ``a`` with a transpose attribute.

    The output tensor comes from ``OutputShaper.transposeOp`` and the
    attribute is built from ``perms``.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    # The attribute is built before validation, as in the original flow.
    perms_attr = ts.TosaSerializerAttribute()
    perms_attr.TransposeAttribute(perms)

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=operand_total,
        input1=a,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, perms_attr)
    return result_tens
1466
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Serialize the operator ``op["op"]`` on ``a`` with a slice attribute.

    The output tensor comes from ``OutputShaper.sliceOp`` and the attribute
    is built from ``start`` and ``size``.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    result_tens = OutputShaper.sliceOp(
        self.ser, self.rng, a, start, size, error_name
    )

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        start=start,
        size=size,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=operand_total,
        input1=a,
    )
    if not is_valid:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], input_list, output_list, slice_attr)
    return result_tens
1505
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Serialize the operator ``op["op"]`` on ``a`` with a tile attribute.

    The output tensor comes from ``OutputShaper.tileOp`` and the attribute
    is built from ``multiples``.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    result_tens = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=operand_total,
        input1=a,
    )
    if not is_valid:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)

    self.ser.addOperator(op["op"], input_list, output_list, tile_attr)
    return result_tens
1540
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Serialize the operator ``op["op"]`` on ``values`` and a generated index tensor.

    A random INT32 constant of shape ``(values.shape[0], W)`` is created with
    entries in ``[0, values.shape[1])``, where ``W`` is drawn from
    ``self.args.tensor_shape_range``.  The output tensor comes from
    ``OutputShaper.gatherOp``.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    # Create a new indices tensor here with data that doesn't exceed the
    # dimensions of the values tensor.
    index_limit = values.shape[1]  # K
    shape_lo, shape_hi = (
        self.args.tensor_shape_range[0],
        self.args.tensor_shape_range[1],
    )
    index_count = self.randInt(shape_lo, shape_hi)  # W
    indices_arr = np.int32(
        self.rng.integers(
            low=0, high=index_limit, size=[values.shape[0], index_count]
        )
    )  # (N, W)
    indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)

    result_tens = OutputShaper.gatherOp(
        self.ser, self.rng, values, indices, error_name
    )

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [result_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=result_tens.shape,
        input_dtype=values.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return result_tens
1587
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Serialize the operator ``op["op"]`` on ``values_in``, generated indices and ``input``.

    A random INT32 constant of shape ``(values_in.shape[0], input.shape[1])``
    is created with entries in ``[0, values_in.shape[1])``.  The output
    tensor comes from ``OutputShaper.scatterOp``.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    # Create a new indices tensor here with data that doesn't exceed the
    # dimensions of the values_in tensor.
    index_limit = values_in.shape[1]  # K
    index_count = input.shape[1]  # W
    indices_arr = np.int32(
        self.rng.integers(
            low=0, high=index_limit, size=[values_in.shape[0], index_count]
        )
    )  # (N, W)
    indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)

    result_tens = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input.name],
        [result_tens.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tens.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return result_tens
1632
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Serialize the operator ``op["op"]`` on ``input`` with a resize attribute.

    The output tensor comes from ``OutputShaper.resizeOp``; the attribute is
    built from ``scale``, ``offset``, ``border`` and ``mode``.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    result_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [result_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=result_tens.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], input_list, output_list, resize_attr)
    return result_tens
1694
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Serialize a two-input, two-output operator for ``val`` and ``val2``.

    Each output tensor comes from ``OutputShaper.unaryOp`` applied to the
    matching input.  ``validator_fcns`` is accepted for signature parity with
    the other builders but is not used here.

    Only the first result tensor is returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)

    # The sibling builders all hand the serializer ``op["op"]``, not the whole
    # operator definition.  Unwrap a dict here for consistency, while still
    # accepting a bare opcode from any caller that passes one.
    opcode = op["op"] if isinstance(op, dict) else op

    self.ser.addOperator(
        opcode, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1702
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as an output tensor on the serializer and return it.

    ``op``, ``validator_fcns`` and ``error_name`` are accepted for signature
    parity with the other builders and are not used.
    """
    output_tensor = val
    self.ser.addOutputTensor(output_tensor)
    return output_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001706
1707 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Serialize the operator ``op["op"]`` converting ``val`` to ``out_dtype``.

    The output tensor comes from ``OutputShaper.typeConversionOp``.  No
    attribute is attached to the operator.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [result_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=result_tens.shape,
        input_dtype=val.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
1740
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Serialize the operator ``op["op"]`` on ``val`` with a rescale attribute.

    Steps, in order (the order of random draws is significant for
    reproducibility):

    1. Create the output tensor via ``OutputShaper.typeConversionOp``.
    2. Pick input and output zero points from the dtypes / ``error_name``.
    3. Draw one random scale per channel (``val.shape[-1]`` channels when
       ``per_channel`` is set, otherwise one) and convert each to a
       multiplier/shift pair with ``TosaQuantGen.computeMultiplierAndShift``.
    4. For ``scale32`` non-error tests, clamp the stored placeholder data so
       that ``value - input_zp`` fits the range implied by each shift, and
       rewrite the placeholder file if anything changed.
    5. Validate ERROR_IF conditions and add the operator.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    nc = val.shape[-1] if per_channel else 1

    def choose_zero_point(dtype, zero_point_errors):
        """Return ``(zero_point, extra_width_bits)`` for one side of the op."""
        if dtype == DType.INT8:
            return self.randInt(-128, 128), 1
        if dtype == DType.UINT8:
            return self.randInt(0, 256), 1
        if error_name in zero_point_errors:
            zero_point = self.randInt(-128, 128)
            if zero_point == 0:
                # The error test needs a zero point that is not zero.
                zero_point = zero_point + self.rng.integers(1, 10)
            return zero_point, 1
        if dtype == DType.UINT16:
            # Must come after the U16...ZeroPointNotValid error check above.
            return self.rng.choice([0, 32768]), 1
        return 0, 0

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Input side is resolved before the output side to keep the sequence of
    # random draws unchanged.
    input_zp, in_extra_bits = choose_zero_point(
        val.dtype,
        [ErrorIf.InputZeroPointNotZero, ErrorIf.U16InputZeroPointNotValid],
    )
    in_type_width += in_extra_bits

    output_zp, out_extra_bits = choose_zero_point(
        out_dtype,
        [ErrorIf.OutputZeroPointNotZero, ErrorIf.U16OutputZeroPointNotValid],
    )
    out_type_width += out_extra_bits

    # Calculate scale based on:
    # scale = a *(2^output_width)/(2^input_width))
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    # Cap the scaling at 2^31 - 1 for scale32 and at 2^15 - 1 for scale16.
    max_scale = (1 << 31) - 1 if scale32 else 32767.0
    scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), max_scale)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Do the bound arithmetic explicitly in int64: mixing a Python int with
        # an int32 scalar leaves the result width to NumPy's promotion rules,
        # and a 32-bit result would overflow for shifts of 32 or more.
        half_range_bits = np.int64(shift_arr[i]) - np.int64(1)
        min_shift_value_arr[i] = np.int64(-1) << half_range_bits
        max_shift_value_arr[i] = (np.int64(1) << half_range_bits) - np.int64(1)

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, allow_pickle=False)

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [result_tens.name]
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1895
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor holding ``cond``.

    The tensor is BOOL with an empty (rank 0) shape by default.  For
    ``ErrorIf.CondIfCondNotMatchingBool`` the dtype is replaced by the result
    of ``get_wrong_output_type``; for ``ErrorIf.CondIfCondShapeNotSizeOne``
    the shape is randomly ``[2]`` or ``[1, 2]``.

    Returns the tensor produced by ``self.ser.addConst``.
    """
    cond_type = (
        get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
1912
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Serialize a conditional operator whose two blocks each output a constant.

    The supplied then/else tensors are ignored except for ``then_tens.shape``,
    which fixes the shape of the INT32 constants and of the result.  For the
    CondIfOutputList{Then,Else}GraphMismatch error tests the matching block
    gets a constant with a deliberately different shape.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    # Create an incorrect output shape for error_if tests
    shape_mismatch_errors = [
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ]
    if error_name in shape_mismatch_errors:
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            # Larger dimensions can shrink or grow; dimensions of 3 or less
            # only grow.
            if dim > 3:
                delta = self.rng.choice([-3, -2, 2, 3])
            else:
                delta = self.rng.choice([1, 2, 4])
            incorrect_shape[idx] = dim + delta
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    # Data for the constants in the then/else blocks
    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # And the result tensor based on any of the outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    cond_attr = ts.TosaSerializerAttribute()
    cond_attr.CondIfAttribute(then_block, else_block)

    # Finally, build the op and the two blocks
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], cond_attr)

    # Build the actual then/else tensors inside their blocks
    self.ser.addBasicBlock(then_block)
    if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
        then_const = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
    else:
        then_const = self.ser.addConst(out_shape, DType.INT32, then_arr)
    self.ser.addOutputTensor(then_const)

    self.ser.addBasicBlock(else_block)
    if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
        else_const = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
    else:
        else_const = self.ser.addConst(out_shape, DType.INT32, else_arr)
    self.ser.addOutputTensor(else_const)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not is_valid:
        return None

    return result_tens
1981
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Serialize a conditional operator whose two blocks each apply a binary op.

    The then/else blocks both take ``a`` and ``b``.  For FP32/BF16/FP16/INT32
    inputs they use ADD and SUB; for INT8/INT16 they use LOGICAL_RIGHT_SHIFT
    and LOGICAL_LEFT_SHIFT.  For the CondIf{Input,Output}List*GraphMismatch
    error tests the matching block gets an input or output whose shape
    deliberately differs from ``a.shape``.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # NOTE(review): unlike build_cond_if_const, small dimensions are not
        # guarded here, so a dimension of 3 or less may become non-positive.
        # Confirm whether that is intended.
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Errors that corrupt a block's input / output, keyed by the block hit.
    input_mismatch_block = {
        ErrorIf.CondIfInputListThenGraphMismatch: then_block,
        ErrorIf.CondIfInputListElseGraphMismatch: else_block,
    }
    output_mismatch_block = {
        ErrorIf.CondIfOutputListThenGraphMismatch: then_block,
        ErrorIf.CondIfOutputListElseGraphMismatch: else_block,
    }

    # The loop variable must not be called ``op``: that would shadow the
    # operator definition that is handed to the validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)

        bad_input = any(
            error_name == err and block == target
            for err, target in input_mismatch_block.items()
        )
        bad_output = not bad_input and any(
            error_name == err and block == target
            for err, target in output_mismatch_block.items()
        )

        self.ser.addInputTensor(incorrect_block_input if bad_input else a)
        self.ser.addInputTensor(b)
        block_out_shape = incorrect_block_input.shape if bad_output else a.shape
        tens = self.ser.addOutput(block_out_shape, a.dtype)

        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2064
Matthew Haddon630c17c2021-10-14 15:05:41 +01002065 def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002066 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07002067
Kevin Cheng550ccc52021-03-03 11:21:43 -08002068 cond_block = "COND_BLOCK"
2069 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002070
2071 attr = ts.TosaSerializerAttribute()
2072 attr.WhileLoopAttribute(cond_block, body_block)
2073
2074 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08002075 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002076 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08002077 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07002078
2079 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002080 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2081 a_out = self.ser.addIntermediate(a.shape, a.dtype)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002082 if error_name == ErrorIf.InputListOutputListMismatch:
2083 incorrect_acc = deepcopy(acc)
2084 for i in range(len(incorrect_acc.shape)):
2085 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2086 acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
2087 else:
2088 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002089
2090 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08002091 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002092 op["op"],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002093 [iter.name, a.name, acc.name],
2094 [iter_out.name, a_out.name, acc_out.name],
2095 attr,
2096 )
Kevin Chengb227ae52021-09-02 13:43:17 -07002097 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07002098
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002099 if error_name in [
2100 ErrorIf.InputListCondGraphMismatch,
2101 ErrorIf.InputListBodyGraphInputMismatch,
2102 ErrorIf.InputListBodyGraphOutputMismatch,
2103 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01002104 incorrect_iter = deepcopy(iter)
2105 for i in range(len(incorrect_iter.shape)):
2106 incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
2107 if len(incorrect_iter.shape) == 0:
2108 incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
2109
2110 incorrect_acc = deepcopy(acc)
2111 for i in range(len(incorrect_acc.shape)):
2112 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2113
Eric Kunzee5e26762020-10-13 16:11:07 -07002114 # COND block (input: iter, output: cond_tens )
Jerry Ge9e94af82022-10-27 09:57:00 -07002115 self.ser.addBasicBlock(cond_block)
2116
Matthew Haddon630c17c2021-10-14 15:05:41 +01002117 if error_name == ErrorIf.InputListCondGraphMismatch:
2118 self.ser.addInputTensor(incorrect_iter)
2119 self.ser.addInputTensor(a)
2120 self.ser.addInputTensor(incorrect_acc)
2121 else:
2122 self.ser.addInputTensor(iter)
2123 self.ser.addInputTensor(a)
2124 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002125 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002126
2127 if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002128 cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002129 else:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002130 cond_type = DType.BOOL
2131 if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
2132 choice = self.rng.choice([1, 2])
2133 if choice == 1:
2134 cond_shape = [3]
2135 else:
2136 cond_shape = [1, 2]
2137 else:
2138 cond_shape = []
2139 cond_tens = self.ser.addOutput(cond_shape, cond_type)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002140
Kevin Cheng550ccc52021-03-03 11:21:43 -08002141 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002142
2143 # BODY block (input: a, acc, iter, output: a, acc, iter)
2144 # Note that local intermediate tensors need to be declared here for the outputs
Jerry Ge9e94af82022-10-27 09:57:00 -07002145 self.ser.addBasicBlock(body_block)
2146
Matthew Haddon630c17c2021-10-14 15:05:41 +01002147 if error_name == ErrorIf.InputListBodyGraphInputMismatch:
2148 self.ser.addInputTensor(incorrect_iter)
2149 self.ser.addInputTensor(a)
2150 self.ser.addInputTensor(incorrect_acc)
2151 else:
2152 self.ser.addInputTensor(iter)
2153 self.ser.addInputTensor(a)
2154 self.ser.addInputTensor(acc)
2155
Kevin Cheng550ccc52021-03-03 11:21:43 -08002156 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002157
2158 if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002159 iter_body_out = self.ser.addIntermediate(
2160 incorrect_iter.shape, incorrect_iter.dtype
2161 )
2162 acc_body_out = self.ser.addIntermediate(
2163 incorrect_acc.shape, incorrect_acc.dtype
2164 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01002165 else:
2166 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2167 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
2168
Eric Kunzee5e26762020-10-13 16:11:07 -07002169 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2170 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2171 self.ser.addOutputTensor(iter_body_out)
2172 self.ser.addOutputTensor(a)
2173 self.ser.addOutputTensor(acc_body_out)
2174
Les Bell729b0352021-11-24 10:28:21 +00002175 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01002176 self.ser,
2177 validator_fcns,
2178 error_name,
2179 op=op,
Jerry Ge9e94af82022-10-27 09:57:00 -07002180 basicBlocks=self.ser.currRegion.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00002181 ):
2182 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01002183
Eric Kunzee5e26762020-10-13 16:11:07 -07002184 return acc_out
2185
Luke Hutton57287132023-02-06 14:54:18 +00002186 def build_fft2d(
2187 self, op, val1, val2, inverse, validator_fcns=None, error_name=None
2188 ):
2189 results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)
2190
2191 input_names = [val1.name, val2.name]
2192 pCount, cCount = op["operands"]
2193 num_operands = pCount + cCount
2194
2195 output_names = [res.name for res in results]
2196 output_shapes = [res.shape for res in results]
2197 output_dtypes = [res.dtype for res in results]
2198
2199 input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
2200 self, error_name, input_names, output_names
2201 )
2202
2203 if not TosaErrorValidator.evValidateErrorIfs(
2204 self.ser,
2205 validator_fcns,
2206 error_name,
2207 op=op,
2208 inverse=inverse,
2209 input1=val1,
2210 input2=val2,
2211 input_shape=val1.shape,
2212 input_dtype=val1.dtype,
2213 output_shape=output_shapes,
2214 output_dtype=output_dtypes,
2215 result_tensors=results,
2216 input_list=input_names,
2217 output_list=output_names,
2218 num_operands=num_operands,
2219 ):
2220 return None
2221
2222 attr = ts.TosaSerializerAttribute()
2223 attr.FFTAttribute(inverse)
2224
2225 self.ser.addOperator(op["op"], input_names, output_names, attr)
2226 return results
2227
Luke Hutton261b7b62023-01-10 14:50:31 +00002228 def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
2229 results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)
2230
2231 input_names = [val.name]
2232 pCount, cCount = op["operands"]
2233 num_operands = pCount + cCount
2234
2235 output_names = [res.name for res in results]
Luke Hutton57287132023-02-06 14:54:18 +00002236 output_shapes = [res.shape for res in results]
Luke Hutton261b7b62023-01-10 14:50:31 +00002237 output_dtypes = [res.dtype for res in results]
2238
2239 input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
2240 self, error_name, input_names, output_names
2241 )
2242
2243 if not TosaErrorValidator.evValidateErrorIfs(
2244 self.ser,
2245 validator_fcns,
2246 error_name,
2247 op=op,
2248 input_shape=val.shape,
2249 input_dtype=val.dtype,
Luke Hutton57287132023-02-06 14:54:18 +00002250 output_shape=output_shapes,
Luke Hutton261b7b62023-01-10 14:50:31 +00002251 output_dtype=output_dtypes,
2252 result_tensors=results,
2253 input_list=input_names,
2254 output_list=output_names,
2255 num_operands=num_operands,
2256 ):
2257 return None
2258
2259 self.ser.addOperator(op["op"], input_names, output_names)
2260 return results
2261
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002262 def create_filter_lists(
2263 self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
2264 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002265 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
2266 default_test_rank_range = range(1, 5)
2267 if not shapeFilter:
2268 shapeFilter = [None]
2269
2270 # Calculate the filters based on what is requested and what the operator allows
2271 rmin, rmax = op["rank"]
2272 if rankFilter is not None:
2273 cleanRankFilter = []
2274 # Ensure rankFilter values are allowed by operator
2275 for rank in rankFilter:
2276 if rank >= rmin and rank <= rmax:
2277 cleanRankFilter.append(rank)
2278 elif rankFilter is None and shapeFilter[0] is None:
Jeremy Johnson03bec732021-10-07 12:06:00 +01002279 # Ensure default behaviour is bounded by default range or by operator,
2280 # whichever is the smaller range of ranks.
2281 opRankRange = range(rmin, rmax + 1)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002282 cleanRankFilter = (
2283 opRankRange
2284 if len(opRankRange) <= len(default_test_rank_range)
2285 else default_test_rank_range
2286 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002287 else:
2288 cleanRankFilter = range(rmin, rmax + 1)
2289
2290 dtypes = op["types"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002291
Matthew Haddon1c00b712021-10-01 15:51:03 +01002292 if dtypeFilter is not None:
2293 cleanDtypeFilter = []
Jeremy Johnson03bec732021-10-07 12:06:00 +01002294 # Create list of operator dtypes filtered by requested dtypes
2295 for dtype in dtypes:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002296 if dtype in dtypeFilter or (
2297 isinstance(dtype, list) and dtype[0] in dtypeFilter
2298 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002299 cleanDtypeFilter.append(dtype)
2300 else:
2301 cleanDtypeFilter = dtypes
2302
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002303 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002304 filterDict = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002305 "shapeFilter": shapeFilter,
2306 "rankFilter": cleanRankFilter,
2307 "dtypeFilter": cleanDtypeFilter,
Matthew Haddon1c00b712021-10-01 15:51:03 +01002308 }
2309 return filterDict
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002310 elif testType == "negative":
Matthew Haddone807aae2021-10-11 18:12:58 +01002311 if validator is not None:
2312 validator_info = validator(check=False, op=op)
2313 else:
2314 return None
2315
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002316 error_arguments = validator_info["param_reqs"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002317
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002318 # Set parameters as required
2319 if error_arguments["rank"] is not None:
2320 rankFilter = error_arguments["rank"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002321 else:
2322 rankFilter = cleanRankFilter
2323
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002324 if error_arguments["dtype"] is not None:
2325 dtypeFilter = error_arguments["dtype"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002326 else:
2327 dtypeFilter = cleanDtypeFilter
2328
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002329 if error_arguments["shape"] is not None:
2330 shapeFilter = error_arguments["shape"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002331 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002332 shapeFilter = shapeFilter[
2333 :2
2334 ] # Reduce number of shapes to keep test numbers small
Matthew Haddon1c00b712021-10-01 15:51:03 +01002335
2336 filterDict = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002337 "shapeFilter": shapeFilter,
2338 "rankFilter": rankFilter,
2339 "dtypeFilter": dtypeFilter,
Matthew Haddon1c00b712021-10-01 15:51:03 +01002340 }
2341 return filterDict
2342
Kevin Cheng550ccc52021-03-03 11:21:43 -08002343 def genOpTestList(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002344 self,
2345 opName,
2346 shapeFilter=[None],
2347 rankFilter=None,
2348 dtypeFilter=None,
2349 testType="positive",
Kevin Cheng550ccc52021-03-03 11:21:43 -08002350 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002351
2352 try:
2353 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002354 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002355 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002356
2357 # Initialize a new random number generator
2358 self.rng = np.random.default_rng(self.random_seed)
2359
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002360 build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002361
Eric Kunzee5e26762020-10-13 16:11:07 -07002362 # Test list consists of a tuple of:
2363 # (opName, testNameStr, dtype, shapeList, argumentsList)
2364 testList = []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002365 if testType == "negative" and "error_if_validators" in op:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002366 error_if_validators = op["error_if_validators"]
2367 else:
2368 error_if_validators = [None]
Eric Kunzee5e26762020-10-13 16:11:07 -07002369
Matthew Haddon1c00b712021-10-01 15:51:03 +01002370 for validator in error_if_validators:
2371 if validator is not None:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002372 error_name = validator(check=False, op=op)["error_name"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002373 else:
2374 error_name = None
2375
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002376 filterDict = self.create_filter_lists(
2377 op, shapeFilter, rankFilter, dtypeFilter, testType, validator
2378 )
2379 if filterDict is None:
Matthew Haddone807aae2021-10-11 18:12:58 +01002380 return []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002381 cleanRankFilter = filterDict["rankFilter"]
2382 cleanDtypeFilter = filterDict["dtypeFilter"]
2383 cleanShapeFilter = filterDict["shapeFilter"]
2384 # print(f"Error: {error_name}, Filters: S {cleanShapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002385
2386 for r in cleanRankFilter:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002387 for t in cleanDtypeFilter:
2388 for shape in cleanShapeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01002389 # Filter out by rank
2390 if shape is not None and len(shape) != r:
2391 continue
Matthew Haddon74567092021-07-16 15:38:20 +01002392 self.setTargetShape(shape)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002393 shapeList = tgen_fcn(self, op, r, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002394
Matthew Haddon74567092021-07-16 15:38:20 +01002395 shapeStr = self.shapeStr(shapeList[0])
2396 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07002397
Matthew Haddon74567092021-07-16 15:38:20 +01002398 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
2399 argList = []
2400 if agen_fcn:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002401 argList = agen_fcn(self, opName, shapeList, t, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002402 else:
Matthew Haddon74567092021-07-16 15:38:20 +01002403 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07002404
Matthew Haddon74567092021-07-16 15:38:20 +01002405 for argStr, args in argList:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002406 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002407 if argStr:
2408 testStr = "{}_{}_{}_{}".format(
2409 opName, shapeStr, typeStr, argStr
2410 )
2411 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002412 testStr = "{}_{}_{}".format(
2413 opName, shapeStr, typeStr
2414 )
2415 elif testType == "negative":
Matthew Haddone86fd342021-09-07 16:12:21 +01002416 if argStr:
2417 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
2418 opName, error_name, shapeStr, typeStr, argStr
2419 )
2420 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002421 testStr = "{}_ERRORIF_{}_{}_{}".format(
2422 opName, error_name, shapeStr, typeStr
2423 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002424
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002425 testList.append(
2426 (opName, testStr, t, error_name, shapeList, args)
2427 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002428
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002429 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002430 # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
2431 if "invalid_test_validators" in op:
2432 invalid_test_validators = op["invalid_test_validators"]
2433 clean_testList = []
2434 for test in testList:
Jeremy Johnson0c716862023-04-13 17:18:19 +01002435 remove_test = False
Matthew Haddon1c00b712021-10-01 15:51:03 +01002436 for validator_fcn in invalid_test_validators:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002437 if validator_fcn(
2438 opName=test[0],
2439 input_dtype=test[2],
2440 shapeList=test[4],
2441 args=test[5],
2442 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002443 remove_test = True
2444 if not remove_test:
2445 clean_testList.append(test)
2446 testList = clean_testList
Eric Kunzee5e26762020-10-13 16:11:07 -07002447
2448 return testList
2449
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002450 def serializeTest(
2451 self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
2452 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002453 try:
2454 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002455 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002456 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002457
Jeremy Johnson0c716862023-04-13 17:18:19 +01002458 if self.args.verbose:
2459 print(f"Creating {testStr}")
2460
Eric Kunzee5e26762020-10-13 16:11:07 -07002461 # Create a serializer
2462 self.createSerializer(opName, testStr)
2463
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002464 build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002465 if "error_if_validators" in op:
2466 error_if_validators = op["error_if_validators"]
2467 else:
2468 error_if_validators = None
2469
Kevin Cheng550ccc52021-03-03 11:21:43 -08002470 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07002471 num_operands = pCount + cCount
2472
2473 if isinstance(dtype_or_dtypeList, list):
2474 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07002475 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002476 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07002477 else:
2478 dtypeList = [dtype_or_dtypeList] * (num_operands)
2479
Kevin Cheng93a16282021-08-31 16:14:03 -07002480 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002481 assert (
2482 len(shapeList) == num_operands
2483 ), "shapeList length {} must match number of operands {}".format(
2484 len(shapeList), num_operands
2485 )
2486 assert (
2487 len(dtypeList) == num_operands
2488 ), "dtypeList length {} must match number of operands {}".format(
2489 len(dtypeList), num_operands
2490 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002491
2492 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002493 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002494 except KeyError:
2495 qgen = None
2496
2497 # Build the random tensor operands and the test
2498 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08002499
Matthew Haddon1c00b712021-10-01 15:51:03 +01002500 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002501 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002502 else:
2503 qinfo = None
2504
Eric Kunzeb5fabec2022-06-07 05:20:44 +00002505 tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
Jeremy Johnson81ee53d2022-03-23 15:32:34 +00002506
Matthew Haddon1c00b712021-10-01 15:51:03 +01002507 try:
2508 if error_if_validators is None:
2509 if qinfo is not None:
2510 resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
2511 else:
2512 resultName = build_fcn(self, op, *tens, *testArgs)
2513 else:
2514 if qinfo is not None:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002515 resultName = build_fcn(
2516 self,
2517 op,
2518 *tens,
2519 *testArgs,
2520 validator_fcns=error_if_validators,
2521 error_name=error_name,
2522 qinfo=qinfo,
2523 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002524 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002525 resultName = build_fcn(
2526 self,
2527 op,
2528 *tens,
2529 *testArgs,
2530 validator_fcns=error_if_validators,
2531 error_name=error_name,
2532 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002533 except TypeError as e:
Les Bell0e027d42021-11-09 14:42:14 +00002534 print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002535 raise e
2536
Les Bell729b0352021-11-24 10:28:21 +00002537 if resultName:
2538 # The test is valid, serialize it
2539 self.serialize("test")
2540 else:
2541 # The test is not valid
2542 print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002543
Eric Kunzee5e26762020-10-13 16:11:07 -07002544 def createDynamicOpLists(self):
2545
Jeremy Johnson00423432022-09-12 17:27:37 +01002546 if "conv2d_TEMPLATE" not in self.TOSA_OP_LIST:
2547 # Already created these lists (can occur when class is initialized more than once)
2548 return
2549
Eric Kunzee5e26762020-10-13 16:11:07 -07002550 # Dynamically create op lists for convolutions with a list of kernel sizes
Jeremy Johnson0c716862023-04-13 17:18:19 +01002551 if not self.args.level8k:
2552 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
2553 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
2554 else:
2555 bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
2556 KERNELS_2D = [[1, bigK], [bigK, 2]]
2557 KERNELS_3D = [[1, bigK, 1], [2, 2, bigK]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002558
Kevin Cheng1533b852021-09-01 12:51:58 -07002559 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002560 testName = "conv2d_{}x{}".format(k[0], k[1])
2561 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
2562 self.TOSA_OP_LIST[testName]["filter"] = k
2563 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002564
Kevin Cheng550ccc52021-03-03 11:21:43 -08002565 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
2566 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2567 "depthwise_conv2d_TEMPLATE"
2568 ].copy()
2569 self.TOSA_OP_LIST[testName]["filter"] = k
2570 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002571
Kevin Cheng550ccc52021-03-03 11:21:43 -08002572 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
2573 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2574 "transpose_conv2d_TEMPLATE"
2575 ].copy()
2576 self.TOSA_OP_LIST[testName]["filter"] = k
2577 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002578
Kevin Cheng1533b852021-09-01 12:51:58 -07002579 for k in KERNELS_3D:
2580 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
2581 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
2582 self.TOSA_OP_LIST[testName]["filter"] = k
2583 self.TOSA_OP_LIST[testName]["template"] = False
2584
Eric Kunzee5e26762020-10-13 16:11:07 -07002585 # Delete any templates after having created any dynamic ops
2586 # This is a two-pass operation because it's bad practice to delete
2587 # keys from dictionaries while iterating
2588 keyList = []
2589 for k in self.TOSA_OP_LIST:
2590 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002591 if self.TOSA_OP_LIST[k]["template"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07002592 keyList.append(k)
2593 continue
2594 except KeyError:
2595 pass
2596
2597 for k in keyList:
2598 del self.TOSA_OP_LIST[k]
2599
2600 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002601 """Fill in default fields for ops if they aren't already specified.
2602 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07002603 for op in self.TOSA_OP_LIST:
2604
2605 # Required fields
2606 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002607 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002608 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002609 raise Exception(
2610 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
2611 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002612
2613 try:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002614 fcn, tgen, tvgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002615 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002616 raise Exception(
2617 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
2618 op
2619 )
2620 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002621
2622 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002623 _ = self.TOSA_OP_LIST[op]["types"]
2624 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002625 raise Exception(
2626 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
2627 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002628
2629 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002630 _ = self.TOSA_OP_LIST[op]["op"]
2631 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002632 raise Exception(
2633 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
2634 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002635
2636 # Put in default rank range, if missing
2637 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002638 _ = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002639 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002640 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002641
    # Tensor operator list
    #  'op': Op enum value of the operator
    #  'operands': tuple of (placeholder, const) operands
    #  'rank': optional, restricts rank to tuple inclusive of (min, max),
    #    if not specified, defaults to DEFAULT_RANK_RANGE
    #  'build_fcn': tuple of (build function, TensorGen function,
    #    TensorValuesGen function, ArgGen function - may be None)
    #  'types': array of datatypes to be tested
James Ward24dbc422022-10-19 12:20:31 +01002649 TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002650
Kevin Cheng550ccc52021-03-03 11:21:43 -08002651 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
James Ward8b390432022-08-12 20:48:56 +01002652 TYPE_INT_FP = [
2653 DType.INT8,
2654 DType.INT16,
2655 DType.INT32,
2656 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01002657 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002658 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01002659 ] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07002660
Kevin Cheng550ccc52021-03-03 11:21:43 -08002661 TYPE_BOOL = [DType.BOOL]
James Ward24dbc422022-10-19 12:20:31 +01002662 TYPE_FI32 = [
2663 DType.FP32,
2664 DType.FP16,
2665 DType.BF16,
2666 DType.INT32,
2667 ] # floating-types and INT32
James Ward8b390432022-08-12 20:48:56 +01002668 TYPE_FIB = [
2669 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01002670 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002671 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01002672 DType.INT8,
2673 DType.INT16,
2674 DType.INT32,
2675 DType.BOOL,
2676 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002677 TYPE_FI16 = [DType.FP32, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002678
James Ward24dbc422022-10-19 12:20:31 +01002679 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
Eric Kunzee5e26762020-10-13 16:11:07 -07002680
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002681 # List of [Input Type 1, Input Type 2, Accumulator Type]
Kevin Cheng1533b852021-09-01 12:51:58 -07002682 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07002683 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002684 [DType.INT8, DType.INT8, DType.INT32],
2685 [DType.INT16, DType.INT8, DType.INT48],
James Ward8b390432022-08-12 20:48:56 +01002686 [DType.FP16, DType.FP16, DType.FP16],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002687 [DType.FP16, DType.FP16, DType.FP32],
James Ward24dbc422022-10-19 12:20:31 +01002688 [DType.BF16, DType.BF16, DType.FP32],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002689 [DType.FP32, DType.FP32, DType.FP32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002690 ]
2691
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002692 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002693
2694 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002695 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002696 "argmax": {
2697 "op": Op.ARGMAX,
2698 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002699 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002700 "build_fcn": (
2701 build_argmax,
2702 TosaTensorGen.tgBasic,
2703 TosaTensorValuesGen.tvgDefault,
2704 TosaArgGen.agAxis,
2705 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002706 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002707 "error_if_validators": (
2708 TosaErrorValidator.evAxisSmallerZero,
2709 TosaErrorValidator.evAxisLargerRank,
2710 TosaErrorValidator.evArgmaxOutputRankMismatch,
2711 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2712 TosaErrorValidator.evWrongRank,
2713 TosaErrorValidator.evWrongInputType,
2714 TosaErrorValidator.evWrongOutputType,
2715 TosaErrorValidator.evWrongInputList,
2716 TosaErrorValidator.evWrongOutputList,
2717 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002718 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002719 "avg_pool2d": {
2720 "op": Op.AVG_POOL2D,
2721 "operands": (1, 0),
2722 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002723 "build_fcn": (
2724 build_pool2d,
2725 TosaTensorGen.tgNHWC,
2726 TosaTensorValuesGen.tvgDefault,
2727 TosaArgGen.agPooling,
2728 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002729 "qgen": TosaQuantGen.qgUnary,
2730 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002731 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002732 "error_if_validators": (
2733 TosaErrorValidator.evKernelSmallerOne,
2734 TosaErrorValidator.evStrideSmallerOne,
2735 TosaErrorValidator.evPadSmallerZero,
2736 TosaErrorValidator.evWrongRank,
2737 TosaErrorValidator.evWrongInputType,
2738 TosaErrorValidator.evWrongOutputType,
2739 TosaErrorValidator.evWrongInputList,
2740 TosaErrorValidator.evWrongOutputList,
2741 TosaErrorValidator.evInputZeroPointNotZero,
2742 TosaErrorValidator.evOutputZeroPointNotZero,
2743 TosaErrorValidator.evPadLargerEqualKernel,
2744 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002745 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002746 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002747 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002748 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002749 "conv2d_TEMPLATE": {
2750 "op": Op.CONV2D,
2751 "operands": (1, 2),
2752 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002753 "build_fcn": (
2754 build_conv2d,
2755 TosaTensorGen.tgConv2D,
2756 TosaTensorValuesGen.tvgDefault,
2757 TosaArgGen.agConv,
2758 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002759 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002760 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002761 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2762 "error_if_validators": (
2763 TosaErrorValidator.evWrongInputType,
2764 TosaErrorValidator.evWrongOutputType,
2765 TosaErrorValidator.evWrongInputList,
2766 TosaErrorValidator.evWrongOutputList,
2767 TosaErrorValidator.evInputZeroPointNotZero,
2768 TosaErrorValidator.evWeightZeroPointNotZero,
2769 TosaErrorValidator.evPadSmallerZero,
2770 TosaErrorValidator.evStrideSmallerOne,
2771 TosaErrorValidator.evDilationSmallerOne,
2772 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002773 TosaErrorValidator.evConvOutputShapeMismatch,
2774 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002775 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002776 "template": True,
2777 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002778 # Templated operator. Filled in by createDynamicOpLists
2779 "conv3d_TEMPLATE": {
2780 "op": Op.CONV3D,
2781 "operands": (1, 2),
2782 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002783 "build_fcn": (
2784 build_conv3d,
2785 TosaTensorGen.tgConv3D,
2786 TosaTensorValuesGen.tvgDefault,
2787 TosaArgGen.agConv,
2788 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002789 "qgen": TosaQuantGen.qgConv,
2790 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002791 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2792 "error_if_validators": (
2793 TosaErrorValidator.evWrongInputType,
2794 TosaErrorValidator.evWrongOutputType,
2795 TosaErrorValidator.evWrongInputList,
2796 TosaErrorValidator.evWrongOutputList,
2797 TosaErrorValidator.evInputZeroPointNotZero,
2798 TosaErrorValidator.evWeightZeroPointNotZero,
2799 TosaErrorValidator.evPadSmallerZero,
2800 TosaErrorValidator.evStrideSmallerOne,
2801 TosaErrorValidator.evDilationSmallerOne,
2802 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002803 TosaErrorValidator.evConvOutputShapeMismatch,
2804 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002805 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002806 "template": True,
2807 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002808 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002809 "depthwise_conv2d_TEMPLATE": {
2810 "op": Op.DEPTHWISE_CONV2D,
2811 "operands": (1, 2),
2812 "filter": [1, 1],
2813 "rank": (4, 4),
2814 "build_fcn": (
2815 build_depthwise_conv2d,
2816 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002817 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002818 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002819 ),
2820 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002821 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002822 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2823 "error_if_validators": (
2824 TosaErrorValidator.evWrongInputType,
2825 TosaErrorValidator.evWrongOutputType,
2826 TosaErrorValidator.evWrongInputList,
2827 TosaErrorValidator.evWrongOutputList,
2828 TosaErrorValidator.evInputZeroPointNotZero,
2829 TosaErrorValidator.evWeightZeroPointNotZero,
2830 TosaErrorValidator.evPadSmallerZero,
2831 TosaErrorValidator.evStrideSmallerOne,
2832 TosaErrorValidator.evDilationSmallerOne,
2833 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002834 TosaErrorValidator.evConvOutputShapeMismatch,
2835 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002836 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002837 "template": True,
2838 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002839 "fully_connected": {
2840 "op": Op.FULLY_CONNECTED,
2841 "operands": (1, 2),
2842 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002843 "build_fcn": (
2844 build_fully_connected,
2845 TosaTensorGen.tgFullyConnected,
2846 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002847 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002848 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002849 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002850 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002851 "error_if_validators": (
2852 TosaErrorValidator.evInputZeroPointNotZero,
2853 TosaErrorValidator.evWeightZeroPointNotZero,
2854 TosaErrorValidator.evWrongRank,
2855 TosaErrorValidator.evWrongInputType,
2856 TosaErrorValidator.evWrongOutputType,
2857 TosaErrorValidator.evWrongInputList,
2858 TosaErrorValidator.evWrongOutputList,
2859 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002860 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002861 "matmul": {
2862 "op": Op.MATMUL,
2863 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002864 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002865 "build_fcn": (
2866 build_matmul,
2867 TosaTensorGen.tgMatmul,
2868 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002869 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002870 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002871 "qgen": TosaQuantGen.qgMatmul,
2872 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002873 "error_if_validators": (
2874 TosaErrorValidator.evInputZeroPointNotZero,
2875 TosaErrorValidator.evWrongRank,
2876 TosaErrorValidator.evWrongInputType,
2877 TosaErrorValidator.evWrongOutputType,
2878 TosaErrorValidator.evWrongInputList,
2879 TosaErrorValidator.evWrongOutputList,
2880 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002881 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002882 "max_pool2d": {
2883 "op": Op.MAX_POOL2D,
2884 "operands": (1, 0),
2885 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002886 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01002887 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002888 TosaTensorGen.tgNHWC,
2889 TosaTensorValuesGen.tvgDefault,
2890 TosaArgGen.agPooling,
2891 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002892 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002893 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002894 "error_if_validators": (
2895 TosaErrorValidator.evKernelSmallerOne,
2896 TosaErrorValidator.evStrideSmallerOne,
2897 TosaErrorValidator.evPadSmallerZero,
2898 TosaErrorValidator.evWrongRank,
2899 TosaErrorValidator.evWrongInputType,
2900 TosaErrorValidator.evWrongOutputType,
2901 TosaErrorValidator.evWrongInputList,
2902 TosaErrorValidator.evWrongOutputList,
2903 TosaErrorValidator.evPadLargerEqualKernel,
2904 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002905 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002906 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002907 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002908 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002909 "transpose_conv2d_TEMPLATE": {
2910 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002911 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002912 "rank": (4, 4),
2913 "build_fcn": (
2914 build_transpose_conv2d,
2915 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002916 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002917 TosaArgGen.agTransposeConv2D,
2918 ),
2919 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002920 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002921 "invalid_test_validators": (
2922 TosaInvalidValidator.ivHeightWidthInvalid,
2923 TosaInvalidValidator.ivNonPositiveOutputShape,
2924 ),
2925 "error_if_validators": (
2926 TosaErrorValidator.evWrongInputType,
2927 TosaErrorValidator.evWrongOutputType,
2928 TosaErrorValidator.evWrongInputList,
2929 TosaErrorValidator.evWrongOutputList,
2930 TosaErrorValidator.evInputZeroPointNotZero,
2931 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07002932 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00002933 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002934 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002935 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002936 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002937 "template": True,
2938 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002939 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002940 "clamp": {
2941 "op": Op.CLAMP,
2942 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002943 "build_fcn": (
2944 build_clamp,
2945 TosaTensorGen.tgBasic,
2946 TosaTensorValuesGen.tvgDefault,
2947 None,
2948 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002949 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002950 "error_if_validators": (
2951 TosaErrorValidator.evMaxSmallerMin,
2952 TosaErrorValidator.evWrongInputType,
2953 TosaErrorValidator.evWrongOutputType,
2954 TosaErrorValidator.evWrongInputList,
2955 TosaErrorValidator.evWrongOutputList,
2956 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002957 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002958 "sigmoid": {
2959 "op": Op.SIGMOID,
2960 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002961 "build_fcn": (
2962 build_sigmoid,
2963 TosaTensorGen.tgBasic,
2964 TosaTensorValuesGen.tvgDefault,
2965 None,
2966 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002967 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002968 "error_if_validators": (
2969 TosaErrorValidator.evWrongInputType,
2970 TosaErrorValidator.evWrongOutputType,
2971 TosaErrorValidator.evWrongInputList,
2972 TosaErrorValidator.evWrongOutputList,
2973 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002974 },
2975 "tanh": {
2976 "op": Op.TANH,
2977 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002978 "build_fcn": (
2979 build_tanh,
2980 TosaTensorGen.tgBasic,
2981 TosaTensorValuesGen.tvgDefault,
2982 None,
2983 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002984 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002985 "error_if_validators": (
2986 TosaErrorValidator.evWrongInputType,
2987 TosaErrorValidator.evWrongOutputType,
2988 TosaErrorValidator.evWrongInputList,
2989 TosaErrorValidator.evWrongOutputList,
2990 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002991 },
Won Jeon78155c62023-06-10 00:20:04 +00002992 "erf": {
2993 "op": Op.ERF,
2994 "operands": (1, 0),
2995 "build_fcn": (
2996 build_erf,
2997 TosaTensorGen.tgBasic,
2998 TosaTensorValuesGen.tvgDefault,
2999 None,
3000 ),
3001 "types": TYPE_FP,
3002 "error_if_validators": (
3003 TosaErrorValidator.evWrongInputType,
3004 TosaErrorValidator.evWrongOutputType,
3005 TosaErrorValidator.evWrongInputList,
3006 TosaErrorValidator.evWrongOutputList,
3007 ),
3008 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003009 # Elementwise Binary Operators
3010 "add": {
3011 "op": Op.ADD,
3012 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003013 "build_fcn": (
3014 build_binary_broadcast,
3015 TosaTensorGen.tgBroadcastFuzz,
3016 TosaTensorValuesGen.tvgAddSub,
3017 None,
3018 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003019 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003020 "error_if_validators": (
3021 TosaErrorValidator.evRankMismatch,
3022 TosaErrorValidator.evWrongInputType,
3023 TosaErrorValidator.evWrongOutputType,
3024 TosaErrorValidator.evWrongInputList,
3025 TosaErrorValidator.evWrongOutputList,
3026 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003027 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003028 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003029 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003030 "arithmetic_right_shift": {
3031 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3032 "operands": (2, 0),
3033 "build_fcn": (
3034 build_arithmetic_right_shift,
3035 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003036 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003037 TosaArgGen.agArithmeticRightShift,
3038 ),
3039 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003040 "error_if_validators": (
3041 TosaErrorValidator.evRankMismatch,
3042 TosaErrorValidator.evWrongInputType,
3043 TosaErrorValidator.evWrongOutputType,
3044 TosaErrorValidator.evWrongInputList,
3045 TosaErrorValidator.evWrongOutputList,
3046 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003047 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003048 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003049 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003050 "bitwise_and": {
3051 "op": Op.BITWISE_AND,
3052 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003053 "build_fcn": (
3054 build_binary_broadcast,
3055 TosaTensorGen.tgBroadcastFuzz,
3056 TosaTensorValuesGen.tvgDefault,
3057 None,
3058 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003059 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003060 "error_if_validators": (
3061 TosaErrorValidator.evRankMismatch,
3062 TosaErrorValidator.evWrongInputType,
3063 TosaErrorValidator.evWrongOutputType,
3064 TosaErrorValidator.evWrongInputList,
3065 TosaErrorValidator.evWrongOutputList,
3066 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003067 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003068 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003069 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003070 "bitwise_or": {
3071 "op": Op.BITWISE_OR,
3072 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003073 "build_fcn": (
3074 build_binary_broadcast,
3075 TosaTensorGen.tgBroadcastFuzz,
3076 TosaTensorValuesGen.tvgDefault,
3077 None,
3078 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003079 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003080 "error_if_validators": (
3081 TosaErrorValidator.evRankMismatch,
3082 TosaErrorValidator.evWrongInputType,
3083 TosaErrorValidator.evWrongOutputType,
3084 TosaErrorValidator.evWrongInputList,
3085 TosaErrorValidator.evWrongOutputList,
3086 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003087 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003088 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003089 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003090 "bitwise_xor": {
3091 "op": Op.BITWISE_XOR,
3092 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003093 "build_fcn": (
3094 build_binary_broadcast,
3095 TosaTensorGen.tgBroadcastFuzz,
3096 TosaTensorValuesGen.tvgDefault,
3097 None,
3098 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003099 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003100 "error_if_validators": (
3101 TosaErrorValidator.evRankMismatch,
3102 TosaErrorValidator.evWrongInputType,
3103 TosaErrorValidator.evWrongOutputType,
3104 TosaErrorValidator.evWrongInputList,
3105 TosaErrorValidator.evWrongOutputList,
3106 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003107 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003108 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003109 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003110 "intdiv": {
3111 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003112 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003113 "build_fcn": (
3114 build_binary_broadcast,
3115 TosaTensorGen.tgBroadcastFuzz,
3116 TosaTensorValuesGen.tvgIntDiv,
3117 None,
3118 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003119 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003120 "error_if_validators": (
3121 TosaErrorValidator.evRankMismatch,
3122 TosaErrorValidator.evWrongInputType,
3123 TosaErrorValidator.evWrongOutputType,
3124 TosaErrorValidator.evWrongInputList,
3125 TosaErrorValidator.evWrongOutputList,
3126 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003127 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003128 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003129 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003130 "logical_and": {
3131 "op": Op.LOGICAL_AND,
3132 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003133 "build_fcn": (
3134 build_binary_broadcast,
3135 TosaTensorGen.tgBroadcastFuzz,
3136 TosaTensorValuesGen.tvgDefault,
3137 None,
3138 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003139 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003140 "error_if_validators": (
3141 TosaErrorValidator.evRankMismatch,
3142 TosaErrorValidator.evWrongInputType,
3143 TosaErrorValidator.evWrongOutputType,
3144 TosaErrorValidator.evWrongInputList,
3145 TosaErrorValidator.evWrongOutputList,
3146 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003147 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003148 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003149 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003150 "logical_left_shift": {
3151 "op": Op.LOGICAL_LEFT_SHIFT,
3152 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003153 "build_fcn": (
3154 build_binary_broadcast,
3155 TosaTensorGen.tgBroadcastFuzz,
3156 TosaTensorValuesGen.tvgLogicalShift,
3157 None,
3158 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003159 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003160 "error_if_validators": (
3161 TosaErrorValidator.evRankMismatch,
3162 TosaErrorValidator.evWrongInputType,
3163 TosaErrorValidator.evWrongOutputType,
3164 TosaErrorValidator.evWrongInputList,
3165 TosaErrorValidator.evWrongOutputList,
3166 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003167 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003168 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003169 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003170 "logical_right_shift": {
3171 "op": Op.LOGICAL_RIGHT_SHIFT,
3172 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003173 "build_fcn": (
3174 build_binary_broadcast,
3175 TosaTensorGen.tgBroadcastFuzz,
3176 TosaTensorValuesGen.tvgLogicalShift,
3177 None,
3178 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003179 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003180 "error_if_validators": (
3181 TosaErrorValidator.evRankMismatch,
3182 TosaErrorValidator.evWrongInputType,
3183 TosaErrorValidator.evWrongOutputType,
3184 TosaErrorValidator.evWrongInputList,
3185 TosaErrorValidator.evWrongOutputList,
3186 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003187 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003188 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003189 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003190 "logical_or": {
3191 "op": Op.LOGICAL_OR,
3192 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003193 "build_fcn": (
3194 build_binary_broadcast,
3195 TosaTensorGen.tgBroadcastFuzz,
3196 TosaTensorValuesGen.tvgDefault,
3197 None,
3198 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003199 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003200 "error_if_validators": (
3201 TosaErrorValidator.evRankMismatch,
3202 TosaErrorValidator.evWrongInputType,
3203 TosaErrorValidator.evWrongOutputType,
3204 TosaErrorValidator.evWrongInputList,
3205 TosaErrorValidator.evWrongOutputList,
3206 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003207 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003208 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003209 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003210 "logical_xor": {
3211 "op": Op.LOGICAL_XOR,
3212 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003213 "build_fcn": (
3214 build_binary_broadcast,
3215 TosaTensorGen.tgBroadcastFuzz,
3216 TosaTensorValuesGen.tvgDefault,
3217 None,
3218 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003219 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003220 "error_if_validators": (
3221 TosaErrorValidator.evRankMismatch,
3222 TosaErrorValidator.evWrongInputType,
3223 TosaErrorValidator.evWrongOutputType,
3224 TosaErrorValidator.evWrongInputList,
3225 TosaErrorValidator.evWrongOutputList,
3226 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003227 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003228 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003229 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003230 "maximum": {
3231 "op": Op.MAXIMUM,
3232 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003233 "build_fcn": (
3234 build_binary_broadcast,
3235 TosaTensorGen.tgBroadcastFuzz,
3236 TosaTensorValuesGen.tvgDefault,
3237 None,
3238 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003239 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003240 "error_if_validators": (
3241 TosaErrorValidator.evRankMismatch,
3242 TosaErrorValidator.evWrongInputType,
3243 TosaErrorValidator.evWrongOutputType,
3244 TosaErrorValidator.evWrongInputList,
3245 TosaErrorValidator.evWrongOutputList,
3246 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003247 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003248 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003249 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003250 "minimum": {
3251 "op": Op.MINIMUM,
3252 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003253 "build_fcn": (
3254 build_binary_broadcast,
3255 TosaTensorGen.tgBroadcastFuzz,
3256 TosaTensorValuesGen.tvgDefault,
3257 None,
3258 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003259 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003260 "error_if_validators": (
3261 TosaErrorValidator.evRankMismatch,
3262 TosaErrorValidator.evWrongInputType,
3263 TosaErrorValidator.evWrongOutputType,
3264 TosaErrorValidator.evWrongInputList,
3265 TosaErrorValidator.evWrongOutputList,
3266 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003267 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003268 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003269 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003270 "mul": {
3271 "op": Op.MUL,
3272 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003273 "build_fcn": (
3274 build_mul,
3275 TosaTensorGen.tgBroadcastFuzz,
3276 TosaTensorValuesGen.tvgMul,
3277 TosaArgGen.agMul,
3278 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003279 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003280 "error_if_validators": (
3281 TosaErrorValidator.evWrongInputType,
3282 TosaErrorValidator.evWrongOutputType,
3283 TosaErrorValidator.evWrongInputList,
3284 TosaErrorValidator.evWrongOutputList,
3285 TosaErrorValidator.evRankMismatch,
3286 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003287 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003288 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003289 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003290 "pow": {
3291 "op": Op.POW,
3292 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003293 "build_fcn": (
3294 build_binary_broadcast,
3295 TosaTensorGen.tgBroadcastFuzz,
3296 TosaTensorValuesGen.tvgDefault,
3297 None,
3298 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003299 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003300 "error_if_validators": (
3301 TosaErrorValidator.evRankMismatch,
3302 TosaErrorValidator.evWrongInputType,
3303 TosaErrorValidator.evWrongOutputType,
3304 TosaErrorValidator.evWrongInputList,
3305 TosaErrorValidator.evWrongOutputList,
3306 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003307 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003308 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003309 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003310 "sub": {
3311 "op": Op.SUB,
3312 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003313 "build_fcn": (
3314 build_binary_broadcast,
3315 TosaTensorGen.tgBroadcastFuzz,
3316 TosaTensorValuesGen.tvgAddSub,
3317 None,
3318 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003319 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003320 "error_if_validators": (
3321 TosaErrorValidator.evRankMismatch,
3322 TosaErrorValidator.evWrongInputType,
3323 TosaErrorValidator.evWrongOutputType,
3324 TosaErrorValidator.evWrongInputList,
3325 TosaErrorValidator.evWrongOutputList,
3326 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003327 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003328 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003329 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003330 "table": {
3331 "op": Op.TABLE,
3332 # Use the automatic generation functions to create the input array
3333 # but create the table tensor in the build function, as it may be
3334 # a different type from the input
3335 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003336 "build_fcn": (
3337 build_table,
3338 TosaTensorGen.tgBasic,
3339 TosaTensorValuesGen.tvgDefault,
3340 TosaArgGen.agTable,
3341 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003342 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003343 "error_if_validators": (
3344 TosaErrorValidator.evWrongInputType,
3345 TosaErrorValidator.evWrongOutputType,
3346 TosaErrorValidator.evWrongInputList,
3347 TosaErrorValidator.evWrongOutputList,
3348 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003349 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003350 # Elementwise Unary operators
3351 "abs": {
3352 "op": Op.ABS,
3353 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003354 "build_fcn": (
3355 build_unary,
3356 TosaTensorGen.tgBasic,
3357 TosaTensorValuesGen.tvgDefault,
3358 None,
3359 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003360 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003361 "error_if_validators": (
3362 TosaErrorValidator.evWrongInputType,
3363 TosaErrorValidator.evWrongOutputType,
3364 TosaErrorValidator.evWrongInputList,
3365 TosaErrorValidator.evWrongOutputList,
3366 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003367 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003368 "bitwise_not": {
3369 "op": Op.BITWISE_NOT,
3370 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003371 "build_fcn": (
3372 build_unary,
3373 TosaTensorGen.tgBasic,
3374 TosaTensorValuesGen.tvgDefault,
3375 None,
3376 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003377 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003378 "error_if_validators": (
3379 TosaErrorValidator.evWrongInputType,
3380 TosaErrorValidator.evWrongOutputType,
3381 TosaErrorValidator.evWrongInputList,
3382 TosaErrorValidator.evWrongOutputList,
3383 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003384 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003385 "ceil": {
3386 "op": Op.CEIL,
3387 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003388 "build_fcn": (
3389 build_unary,
3390 TosaTensorGen.tgBasic,
3391 TosaTensorValuesGen.tvgDefault,
3392 None,
3393 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003394 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003395 "error_if_validators": (
3396 TosaErrorValidator.evWrongInputType,
3397 TosaErrorValidator.evWrongOutputType,
3398 TosaErrorValidator.evWrongInputList,
3399 TosaErrorValidator.evWrongOutputList,
3400 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003401 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003402 "clz": {
3403 "op": Op.CLZ,
3404 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003405 "build_fcn": (
3406 build_unary,
3407 TosaTensorGen.tgBasic,
3408 TosaTensorValuesGen.tvgDefault,
3409 None,
3410 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003411 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003412 "error_if_validators": (
3413 TosaErrorValidator.evWrongInputType,
3414 TosaErrorValidator.evWrongOutputType,
3415 TosaErrorValidator.evWrongInputList,
3416 TosaErrorValidator.evWrongOutputList,
3417 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003418 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003419 "exp": {
3420 "op": Op.EXP,
3421 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003422 "build_fcn": (
3423 build_unary,
3424 TosaTensorGen.tgBasic,
3425 TosaTensorValuesGen.tvgDefault,
3426 None,
3427 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003428 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003429 "error_if_validators": (
3430 TosaErrorValidator.evWrongInputType,
3431 TosaErrorValidator.evWrongOutputType,
3432 TosaErrorValidator.evWrongInputList,
3433 TosaErrorValidator.evWrongOutputList,
3434 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003435 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003436 "floor": {
3437 "op": Op.FLOOR,
3438 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003439 "build_fcn": (
3440 build_unary,
3441 TosaTensorGen.tgBasic,
3442 TosaTensorValuesGen.tvgDefault,
3443 None,
3444 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003445 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003446 "error_if_validators": (
3447 TosaErrorValidator.evWrongInputType,
3448 TosaErrorValidator.evWrongOutputType,
3449 TosaErrorValidator.evWrongInputList,
3450 TosaErrorValidator.evWrongOutputList,
3451 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003452 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003453 "log": {
3454 "op": Op.LOG,
3455 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003456 "build_fcn": (
3457 build_unary,
3458 TosaTensorGen.tgBasic,
3459 TosaTensorValuesGen.tvgDefault,
3460 None,
3461 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003462 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003463 "error_if_validators": (
3464 TosaErrorValidator.evWrongInputType,
3465 TosaErrorValidator.evWrongOutputType,
3466 TosaErrorValidator.evWrongInputList,
3467 TosaErrorValidator.evWrongOutputList,
3468 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003469 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003470 "logical_not": {
3471 "op": Op.LOGICAL_NOT,
3472 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003473 "build_fcn": (
3474 build_unary,
3475 TosaTensorGen.tgBasic,
3476 TosaTensorValuesGen.tvgDefault,
3477 None,
3478 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003479 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003480 "error_if_validators": (
3481 TosaErrorValidator.evWrongInputType,
3482 TosaErrorValidator.evWrongOutputType,
3483 TosaErrorValidator.evWrongInputList,
3484 TosaErrorValidator.evWrongOutputList,
3485 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003486 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003487 "negate": {
3488 "op": Op.NEGATE,
3489 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003490 "build_fcn": (
3491 build_unary,
3492 TosaTensorGen.tgBasic,
3493 TosaTensorValuesGen.tvgNegate,
3494 None,
3495 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003496 "qgen": TosaQuantGen.qgUnary,
3497 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003498 "error_if_validators": (
3499 TosaErrorValidator.evInputZeroPointNotZero,
3500 TosaErrorValidator.evOutputZeroPointNotZero,
3501 TosaErrorValidator.evWrongInputType,
3502 TosaErrorValidator.evWrongOutputType,
3503 TosaErrorValidator.evWrongInputList,
3504 TosaErrorValidator.evWrongOutputList,
3505 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003506 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003507 "reciprocal": {
3508 "op": Op.RECIPROCAL,
3509 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003510 "build_fcn": (
3511 build_unary,
3512 TosaTensorGen.tgBasic,
3513 TosaTensorValuesGen.tvgDefault,
3514 None,
3515 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003516 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003517 "error_if_validators": (
3518 TosaErrorValidator.evWrongInputType,
3519 TosaErrorValidator.evWrongOutputType,
3520 TosaErrorValidator.evWrongInputList,
3521 TosaErrorValidator.evWrongOutputList,
3522 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003523 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003524 "rsqrt": {
3525 "op": Op.RSQRT,
3526 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003527 "build_fcn": (
3528 build_unary,
3529 TosaTensorGen.tgBasic,
3530 TosaTensorValuesGen.tvgDefault,
3531 None,
3532 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003533 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003534 "error_if_validators": (
3535 TosaErrorValidator.evWrongInputType,
3536 TosaErrorValidator.evWrongOutputType,
3537 TosaErrorValidator.evWrongInputList,
3538 TosaErrorValidator.evWrongOutputList,
3539 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003540 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003541 # Elementwise Ternary operators
3542 "select": {
3543 "op": Op.SELECT,
3544 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003545 "build_fcn": (
3546 build_select,
3547 TosaTensorGen.tgBroadcastFuzz,
3548 TosaTensorValuesGen.tvgSelect,
3549 None,
3550 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003551 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003552 "error_if_validators": (
3553 TosaErrorValidator.evRankMismatch,
3554 TosaErrorValidator.evWrongInputType,
3555 TosaErrorValidator.evWrongOutputType,
3556 TosaErrorValidator.evWrongInputList,
3557 TosaErrorValidator.evWrongOutputList,
3558 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003559 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003560 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003561 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003562 # Comparison operators
3563 "equal": {
3564 "op": Op.EQUAL,
3565 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003566 "build_fcn": (
3567 build_comparison,
3568 TosaTensorGen.tgBroadcastFuzz,
3569 TosaTensorValuesGen.tvgEqual,
3570 None,
3571 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003572 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003573 "error_if_validators": (
3574 TosaErrorValidator.evRankMismatch,
3575 TosaErrorValidator.evWrongInputType,
3576 TosaErrorValidator.evWrongOutputType,
3577 TosaErrorValidator.evWrongInputList,
3578 TosaErrorValidator.evWrongOutputList,
3579 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003580 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003581 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003582 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003583 "greater_equal": {
3584 "op": Op.GREATER_EQUAL,
3585 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003586 "build_fcn": (
3587 build_comparison,
3588 TosaTensorGen.tgBroadcastFuzz,
3589 TosaTensorValuesGen.tvgDefault,
3590 None,
3591 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003592 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003593 "error_if_validators": (
3594 TosaErrorValidator.evRankMismatch,
3595 TosaErrorValidator.evWrongInputType,
3596 TosaErrorValidator.evWrongOutputType,
3597 TosaErrorValidator.evWrongInputList,
3598 TosaErrorValidator.evWrongOutputList,
3599 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003600 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003601 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003602 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003603 "greater": {
3604 "op": Op.GREATER,
3605 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003606 "build_fcn": (
3607 build_comparison,
3608 TosaTensorGen.tgBroadcastFuzz,
3609 TosaTensorValuesGen.tvgDefault,
3610 None,
3611 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003612 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003613 "error_if_validators": (
3614 TosaErrorValidator.evRankMismatch,
3615 TosaErrorValidator.evWrongInputType,
3616 TosaErrorValidator.evWrongOutputType,
3617 TosaErrorValidator.evWrongInputList,
3618 TosaErrorValidator.evWrongOutputList,
3619 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003620 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003621 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003622 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003623 # Reduction operators
3624 "reduce_all": {
3625 "op": Op.REDUCE_ALL,
3626 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003627 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003628 "build_fcn": (
3629 build_reduce,
3630 TosaTensorGen.tgBasic,
3631 TosaTensorValuesGen.tvgDefault,
3632 TosaArgGen.agAxis,
3633 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003634 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003635 "error_if_validators": (
3636 TosaErrorValidator.evAxisLargerRank,
3637 TosaErrorValidator.evAxisSmallerZero,
3638 TosaErrorValidator.evShapeOfAxisNotOne,
3639 TosaErrorValidator.evWrongInputType,
3640 TosaErrorValidator.evWrongOutputType,
3641 TosaErrorValidator.evWrongRank,
3642 TosaErrorValidator.evWrongInputList,
3643 TosaErrorValidator.evWrongOutputList,
3644 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003645 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003646 "reduce_any": {
3647 "op": Op.REDUCE_ANY,
3648 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003649 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003650 "build_fcn": (
3651 build_reduce,
3652 TosaTensorGen.tgBasic,
3653 TosaTensorValuesGen.tvgDefault,
3654 TosaArgGen.agAxis,
3655 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003656 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003657 "error_if_validators": (
3658 TosaErrorValidator.evAxisLargerRank,
3659 TosaErrorValidator.evAxisSmallerZero,
3660 TosaErrorValidator.evShapeOfAxisNotOne,
3661 TosaErrorValidator.evWrongInputType,
3662 TosaErrorValidator.evWrongOutputType,
3663 TosaErrorValidator.evWrongRank,
3664 TosaErrorValidator.evWrongInputList,
3665 TosaErrorValidator.evWrongOutputList,
3666 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003667 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003668 "reduce_max": {
3669 "op": Op.REDUCE_MAX,
3670 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003671 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003672 "build_fcn": (
3673 build_reduce,
3674 TosaTensorGen.tgBasic,
3675 TosaTensorValuesGen.tvgDefault,
3676 TosaArgGen.agAxis,
3677 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003678 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003679 "error_if_validators": (
3680 TosaErrorValidator.evAxisLargerRank,
3681 TosaErrorValidator.evAxisSmallerZero,
3682 TosaErrorValidator.evShapeOfAxisNotOne,
3683 TosaErrorValidator.evWrongInputType,
3684 TosaErrorValidator.evWrongOutputType,
3685 TosaErrorValidator.evWrongRank,
3686 TosaErrorValidator.evWrongInputList,
3687 TosaErrorValidator.evWrongOutputList,
3688 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003689 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003690 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003691 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003692 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003693 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003694 "build_fcn": (
3695 build_reduce,
3696 TosaTensorGen.tgBasic,
3697 TosaTensorValuesGen.tvgDefault,
3698 TosaArgGen.agAxis,
3699 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003700 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003701 "error_if_validators": (
3702 TosaErrorValidator.evAxisLargerRank,
3703 TosaErrorValidator.evAxisSmallerZero,
3704 TosaErrorValidator.evShapeOfAxisNotOne,
3705 TosaErrorValidator.evWrongInputType,
3706 TosaErrorValidator.evWrongOutputType,
3707 TosaErrorValidator.evWrongRank,
3708 TosaErrorValidator.evWrongInputList,
3709 TosaErrorValidator.evWrongOutputList,
3710 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003711 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003712 "reduce_product": {
3713 "op": Op.REDUCE_PRODUCT,
3714 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003715 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003716 "build_fcn": (
3717 build_reduce,
3718 TosaTensorGen.tgBasic,
3719 TosaTensorValuesGen.tvgDefault,
3720 TosaArgGen.agAxis,
3721 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003722 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003723 "error_if_validators": (
3724 TosaErrorValidator.evAxisLargerRank,
3725 TosaErrorValidator.evAxisSmallerZero,
3726 TosaErrorValidator.evShapeOfAxisNotOne,
3727 TosaErrorValidator.evWrongInputType,
3728 TosaErrorValidator.evWrongOutputType,
3729 TosaErrorValidator.evWrongRank,
3730 TosaErrorValidator.evWrongInputList,
3731 TosaErrorValidator.evWrongOutputList,
3732 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003733 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003734 "reduce_sum": {
3735 "op": Op.REDUCE_SUM,
3736 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003737 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003738 "build_fcn": (
3739 build_reduce,
3740 TosaTensorGen.tgBasic,
3741 TosaTensorValuesGen.tvgReduceSum,
3742 TosaArgGen.agAxis,
3743 ),
James Ward24dbc422022-10-19 12:20:31 +01003744 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003745 "error_if_validators": (
3746 TosaErrorValidator.evAxisLargerRank,
3747 TosaErrorValidator.evAxisSmallerZero,
3748 TosaErrorValidator.evShapeOfAxisNotOne,
3749 TosaErrorValidator.evWrongInputType,
3750 TosaErrorValidator.evWrongOutputType,
3751 TosaErrorValidator.evWrongRank,
3752 TosaErrorValidator.evWrongInputList,
3753 TosaErrorValidator.evWrongOutputList,
3754 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003755 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003756 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003757 "concat": {
3758 "op": Op.CONCAT,
3759 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003760 "build_fcn": (
3761 build_concat,
3762 TosaTensorGen.tgConcat,
3763 TosaTensorValuesGen.tvgConcat,
3764 TosaArgGen.agAxis,
3765 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003766 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003767 "error_if_validators": (
3768 TosaErrorValidator.evAxisLargerRank,
3769 TosaErrorValidator.evAxisSmallerZero,
3770 TosaErrorValidator.evConcatInputRankMismatch,
3771 TosaErrorValidator.evConcatShapeSumMismatch,
3772 TosaErrorValidator.evConcatInputDimMismatch,
3773 TosaErrorValidator.evWrongInputType,
3774 TosaErrorValidator.evWrongOutputType,
3775 TosaErrorValidator.evWrongOutputList,
3776 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003777 },
3778 "pad": {
3779 "op": Op.PAD,
3780 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003781 "build_fcn": (
3782 build_pad,
3783 TosaTensorGen.tgBasic,
3784 TosaTensorValuesGen.tvgDefault,
3785 TosaArgGen.agPad,
3786 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003787 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003788 "error_if_validators": (
3789 TosaErrorValidator.evWrongInputType,
3790 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003791 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003792 TosaErrorValidator.evWrongOutputType,
3793 TosaErrorValidator.evWrongInputList,
3794 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003795 TosaErrorValidator.evRankMismatch,
3796 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003797 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003798 },
Won Jeona21b2e82023-08-10 10:33:01 +00003799 "dim": {
3800 "op": Op.DIM,
3801 "operands": (1, 0),
3802 "build_fcn": (
3803 build_dim,
3804 TosaTensorGen.tgBasic,
3805 TosaTensorValuesGen.tvgDefault,
3806 TosaArgGen.agAxis,
3807 ),
3808 "types": TYPE_FIB,
3809 "error_if_validators": (
3810 TosaErrorValidator.evAxisLargerRank,
3811 TosaErrorValidator.evAxisSmallerZero,
3812 TosaErrorValidator.evWrongInputType,
3813 TosaErrorValidator.evWrongInputList,
3814 TosaErrorValidator.evWrongOutputList,
3815 TosaErrorValidator.evWrongRank,
3816 ),
3817 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003818 "reshape": {
3819 "op": Op.RESHAPE,
3820 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003821 "build_fcn": (
3822 build_reshape,
3823 TosaTensorGen.tgBasic,
3824 TosaTensorValuesGen.tvgDefault,
3825 TosaArgGen.agReshape,
3826 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003827 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003828 "error_if_validators": (
3829 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3830 TosaErrorValidator.evWrongInputType,
3831 TosaErrorValidator.evWrongOutputType,
3832 TosaErrorValidator.evWrongInputList,
3833 TosaErrorValidator.evWrongOutputList,
Jerry Ge264f7fa2023-04-21 22:49:57 +00003834 TosaErrorValidator.evReshapeOutputSizeMultiInference,
3835 TosaErrorValidator.evReshapeOutputSizeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003836 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003837 },
3838 "reverse": {
3839 "op": Op.REVERSE,
3840 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003841 "build_fcn": (
3842 build_reverse,
3843 TosaTensorGen.tgBasic,
3844 TosaTensorValuesGen.tvgDefault,
3845 TosaArgGen.agAxis,
3846 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003847 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003848 "error_if_validators": (
3849 TosaErrorValidator.evAxisSmallerZero,
3850 TosaErrorValidator.evAxisLargerRank,
3851 TosaErrorValidator.evWrongInputType,
3852 TosaErrorValidator.evWrongOutputType,
3853 TosaErrorValidator.evWrongInputList,
3854 TosaErrorValidator.evWrongOutputList,
3855 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003856 },
3857 "slice": {
3858 "op": Op.SLICE,
3859 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003860 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003861 "build_fcn": (
3862 build_slice,
3863 TosaTensorGen.tgBasic,
3864 TosaTensorValuesGen.tvgDefault,
3865 TosaArgGen.agSlice,
3866 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003867 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003868 "error_if_validators": (
3869 TosaErrorValidator.evStartSmallerZero,
3870 TosaErrorValidator.evSizeSmallerEqualZero,
3871 TosaErrorValidator.evStartSizeOutsideBounds,
3872 TosaErrorValidator.evSizeOutputShapeMismatch,
3873 TosaErrorValidator.evInputSizeStartLengthMismatch,
3874 TosaErrorValidator.evWrongRank,
3875 TosaErrorValidator.evWrongInputType,
3876 TosaErrorValidator.evWrongOutputType,
3877 TosaErrorValidator.evWrongInputList,
3878 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003879 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003880 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003881 },
3882 "tile": {
3883 "op": Op.TILE,
3884 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003885 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003886 "build_fcn": (
3887 build_tile,
3888 TosaTensorGen.tgBasic,
3889 TosaTensorValuesGen.tvgDefault,
3890 TosaArgGen.agTile,
3891 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003892 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003893 "error_if_validators": (
3894 TosaErrorValidator.evWrongInputType,
3895 TosaErrorValidator.evWrongOutputType,
3896 TosaErrorValidator.evWrongInputList,
3897 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003898 TosaErrorValidator.evRankMismatch,
3899 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003900 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003901 },
3902 "transpose": {
3903 "op": Op.TRANSPOSE,
3904 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003905 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003906 "build_fcn": (
3907 build_transpose,
3908 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003909 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003910 TosaArgGen.agTranspose,
3911 ),
3912 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003913 "error_if_validators": (
3914 TosaErrorValidator.evIndexOutsideBounds,
3915 TosaErrorValidator.evIndexUsedTwice,
3916 TosaErrorValidator.evWrongInputType,
3917 TosaErrorValidator.evWrongOutputType,
3918 TosaErrorValidator.evWrongInputList,
3919 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003920 TosaErrorValidator.evWrongRank,
3921 TosaErrorValidator.evRankMismatch,
3922 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003923 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003924 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003925 # Data nodes
3926 "const": {
3927 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003928 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003929 "build_fcn": (
3930 build_const,
3931 TosaTensorGen.tgBasic,
3932 TosaTensorValuesGen.tvgDefault,
3933 None,
3934 ),
Luke Hutton65872422023-02-20 10:33:04 +00003935 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08003936 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003937 "identity": {
3938 "op": Op.IDENTITY,
3939 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003940 "build_fcn": (
3941 build_unary,
3942 TosaTensorGen.tgBasic,
3943 TosaTensorValuesGen.tvgDefault,
3944 None,
3945 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003946 "types": TYPE_FIB,
3947 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003948 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003949 "gather": {
3950 "op": Op.GATHER,
3951 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3952 "operands": (1, 0),
3953 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003954 "build_fcn": (
3955 build_gather,
3956 TosaTensorGen.tgBasic,
3957 TosaTensorValuesGen.tvgDefault,
3958 None,
3959 ),
James Ward24dbc422022-10-19 12:20:31 +01003960 "types": (
3961 DType.INT8,
3962 DType.INT16,
3963 DType.INT32,
3964 DType.FP16,
3965 DType.BF16,
3966 DType.FP32,
3967 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003968 "error_if_validators": (
3969 TosaErrorValidator.evWrongInputType,
3970 TosaErrorValidator.evWrongOutputType,
3971 TosaErrorValidator.evWrongInputList,
3972 TosaErrorValidator.evWrongOutputList,
3973 TosaErrorValidator.evWrongRank,
3974 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003975 },
3976 "scatter": {
3977 "op": Op.SCATTER,
3978 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003979 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003980 "operands": (2, 0),
3981 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003982 "build_fcn": (
3983 build_scatter,
3984 TosaTensorGen.tgScatter,
3985 TosaTensorValuesGen.tvgDefault,
3986 None,
3987 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003988 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003989 "error_if_validators": (
3990 TosaErrorValidator.evWrongInputType,
3991 TosaErrorValidator.evWrongOutputType,
3992 TosaErrorValidator.evWrongInputList,
3993 TosaErrorValidator.evWrongOutputList,
3994 TosaErrorValidator.evWrongRank,
3995 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003996 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003997 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003998 "resize": {
3999 "op": Op.RESIZE,
4000 "operands": (1, 0),
4001 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004002 "build_fcn": (
4003 build_resize,
4004 TosaTensorGen.tgNHWC,
4005 TosaTensorValuesGen.tvgDefault,
4006 TosaArgGen.agResize,
4007 ),
James Ward24dbc422022-10-19 12:20:31 +01004008 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004009 "invalid_test_validators": (
4010 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004011 ),
4012 "error_if_validators": (
4013 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004014 TosaErrorValidator.evScaleSmallerEqualZero,
4015 TosaErrorValidator.evScaleNLargerMax,
4016 TosaErrorValidator.evScaleDLargerMax,
4017 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004018 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004019 TosaErrorValidator.evBorderSmallerMin,
4020 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004021 TosaErrorValidator.evWrongInputType,
4022 TosaErrorValidator.evWrongOutputType,
4023 TosaErrorValidator.evWrongRank,
4024 TosaErrorValidator.evWrongInputList,
4025 TosaErrorValidator.evWrongOutputList,
4026 TosaErrorValidator.evBatchMismatch,
4027 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004028 TosaErrorValidator.evResizeOutputShapeMismatch,
4029 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004030 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004031 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004032 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004033 "cast": {
4034 "op": Op.CAST,
4035 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004036 "build_fcn": (
4037 build_cast,
4038 TosaTensorGen.tgBasic,
4039 TosaTensorValuesGen.tvgDefault,
4040 TosaArgGen.agCast,
4041 ),
James Ward8b390432022-08-12 20:48:56 +01004042 "types": (
4043 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004044 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004045 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004046 DType.INT8,
4047 DType.INT16,
4048 DType.INT32,
4049 DType.BOOL,
4050 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004051 "error_if_validators": (
4052 TosaErrorValidator.evWrongInputType,
4053 TosaErrorValidator.evWrongOutputType,
4054 TosaErrorValidator.evWrongInputList,
4055 TosaErrorValidator.evWrongOutputList,
4056 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004057 },
4058 "rescale": {
4059 "op": Op.RESCALE,
4060 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004061 "build_fcn": (
4062 build_rescale,
4063 TosaTensorGen.tgBasic,
4064 TosaTensorValuesGen.tvgDefault,
4065 TosaArgGen.agRescale,
4066 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004067 "types": [
4068 DType.UINT8,
4069 DType.INT8,
4070 DType.INT16,
4071 DType.INT32,
4072 DType.INT48,
4073 DType.UINT16,
4074 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004075 "error_if_validators": (
4076 TosaErrorValidator.evInputZeroPointNotZero,
4077 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004078 TosaErrorValidator.evU16InputZeroPointNotValid,
4079 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004080 TosaErrorValidator.evScaleTrue,
4081 TosaErrorValidator.evScaleNotTrue,
4082 TosaErrorValidator.evWrongInputType,
4083 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004084 TosaErrorValidator.evWrongInputList,
4085 TosaErrorValidator.evWrongOutputList,
4086 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004087 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004088 # Custom
4089 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004090 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004091 # Two variants of cond_if, one that generates one of two constant tensors (no
4092 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4093 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004094 "cond_if_const": {
4095 "op": Op.COND_IF,
4096 "operands": (0, 2),
4097 "build_fcn": (
4098 build_cond_if_const,
4099 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004100 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004101 TosaArgGen.agCondIf,
4102 ),
4103 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004104 "error_if_validators": (
4105 TosaErrorValidator.evOutputListThenGraphMismatch,
4106 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004107 TosaErrorValidator.evCondIfCondNotMatchingBool,
4108 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004109 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004110 },
4111 "cond_if_binary": {
4112 "op": Op.COND_IF,
4113 "operands": (2, 0),
4114 "build_fcn": (
4115 build_cond_if_binary,
4116 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004117 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004118 TosaArgGen.agCondIf,
4119 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004120 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004121 "error_if_validators": (
4122 TosaErrorValidator.evInputListThenGraphMismatch,
4123 TosaErrorValidator.evInputListElseGraphMismatch,
4124 TosaErrorValidator.evOutputListThenGraphMismatch,
4125 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004126 TosaErrorValidator.evCondIfCondNotMatchingBool,
4127 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004128 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004129 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004130 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004131 "while_loop": {
4132 "op": Op.WHILE_LOOP,
4133 "operands": (0, 1),
4134 "build_fcn": (
4135 build_while_loop,
4136 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004137 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004138 TosaArgGen.agWhileLoop,
4139 ),
4140 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004141 "error_if_validators": (
4142 TosaErrorValidator.evInputListOutputListMismatch,
4143 TosaErrorValidator.evInputListCondGraphMismatch,
4144 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4145 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4146 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004147 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004148 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004149 },
Luke Hutton57287132023-02-06 14:54:18 +00004150 "fft2d": {
4151 "op": Op.FFT2D,
4152 "operands": (2, 0),
4153 "rank": (3, 3),
4154 "build_fcn": (
4155 build_fft2d,
4156 TosaTensorGen.tgFFT2d,
4157 TosaTensorValuesGen.tvgDefault,
4158 TosaArgGen.agFFT2d,
4159 ),
4160 "types": [DType.FP32],
4161 "error_if_validators": (
4162 TosaErrorValidator.evWrongInputType,
4163 TosaErrorValidator.evWrongOutputType,
4164 TosaErrorValidator.evWrongInputList,
4165 TosaErrorValidator.evWrongOutputList,
4166 TosaErrorValidator.evWrongRank,
4167 TosaErrorValidator.evBatchMismatch,
4168 TosaErrorValidator.evKernelNotPowerOfTwo,
4169 TosaErrorValidator.evFFTInputShapeMismatch,
4170 TosaErrorValidator.evFFTOutputShapeMismatch,
4171 ),
4172 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004173 "rfft2d": {
4174 "op": Op.RFFT2D,
4175 "operands": (1, 0),
4176 "rank": (3, 3),
4177 "build_fcn": (
4178 build_rfft2d,
4179 TosaTensorGen.tgRFFT2d,
4180 TosaTensorValuesGen.tvgDefault,
4181 TosaArgGen.agNone,
4182 ),
4183 "types": [DType.FP32],
4184 "error_if_validators": (
4185 TosaErrorValidator.evWrongInputType,
4186 TosaErrorValidator.evWrongOutputType,
4187 TosaErrorValidator.evWrongInputList,
4188 TosaErrorValidator.evWrongOutputList,
4189 TosaErrorValidator.evWrongRank,
4190 TosaErrorValidator.evBatchMismatch,
4191 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004192 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004193 ),
4194 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004195 }
4196
Kevin Cheng550ccc52021-03-03 11:21:43 -08004197
Eric Kunzee5e26762020-10-13 16:11:07 -07004198class OutputShaper:
4199 # Methods in this class compute the expected output shape and datatype
4200 # for common classes of operations
4201 def __init__(self):
4202 pass
4203
4204 # These methods return arguments that can be used for
4205 # creating a new output tensor
4206 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a two-input broadcasting operator.

    Every output dimension is copied from ``a``; when no error is being
    injected (``error_name is None``) a dimension of size 1 in ``a`` is
    replaced by the corresponding dimension of ``b``.

    Error injection handled here:
      * ``ErrorIf.RankMismatch``     - the equal-rank assertion is skipped.
      * ``ErrorIf.DimensionMismatch`` - one randomly picked output dimension
        is incremented by one.
      * ``ErrorIf.WrongOutputType``  - the output dtype is a random pick from
        a fixed dtype list with ``a.dtype`` removed.

    ``rng`` must provide ``integers(low, high)`` and ``choice(seq)``
    (presumably a numpy ``Generator`` - confirm against callers).

    Returns whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    broadcast_allowed = error_name is None
    out_shape = [
        b.shape[axis] if (dim == 1 and broadcast_allowed) else dim
        for axis, dim in enumerate(a.shape)
    ]

    # Drawn on every call, even when no dimension mismatch is injected, so
    # the rng advances regardless of error_name.
    bump_axis = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[bump_axis] += 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        # Any dtype other than the input's is a "wrong" output type.
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004239
4240 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for a two-input operator without broadcasting.

    ``a`` and ``b`` must agree in rank, dtype and every dimension; each of
    these is checked with an assertion.  The output takes the shape (as a
    new list) and dtype of ``a``.

    Returns whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for axis, dim in enumerate(a.shape):
        # No broadcasting: the operands must match on every axis.
        assert dim == b.shape[axis]
        out_shape.append(dim)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004251
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor of an elementwise unary operator.

    The output has the shape of `a`. Its dtype is a.dtype, unless
    error_name is ErrorIf.WrongOutputType, in which case a dtype different
    from a.dtype is chosen with rng.

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing choice().
        a: input operand, read through .shape and .dtype.
        error_name: optional ErrorIf selector.

    Returns:
        The value returned by ser.addOutput(a.shape, dtype).
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        candidates = list(set(all_dtypes) - {a.dtype})
        out_dtype = rng.choice(candidates)

    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004270
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor of a select operator.

    The output shape follows `cond`; when no error is requested, every
    size-1 dimension of `cond` becomes the largest of the matching
    dimensions of `cond`, `a` and `b`.

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing integers() and choice().
        cond: condition operand, read through .shape.
        a: first value operand, read through .shape and .dtype.
        b: second value operand, read through .shape and .dtype.
        error_name: optional ErrorIf selector. DimensionMismatch grows one
            randomly chosen output dimension by 1; WrongOutputType picks an
            output dtype different from a.dtype.

    Returns:
        The value returned by ser.addOutput(shape, dtype).
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    broadcast = error_name is None
    out_shape = [
        max(dim, a.shape[idx], b.shape[idx]) if (dim == 1 and broadcast) else dim
        for idx, dim in enumerate(cond.shape)
    ]

    # The index is drawn for every error_name, although only
    # DimensionMismatch uses it.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        candidates = list(set(all_dtypes) - {a.dtype})
        out_dtype = rng.choice(candidates)
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004304
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output tensor of a broadcasting comparison operator.

    The output shape follows `a`, with each size-1 dimension of `a`
    replaced by the matching dimension of `b` when `b` has that dimension.
    The output dtype is DType.BOOL unless error_name is
    ErrorIf.WrongOutputType.

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing integers() and choice().
        a: first operand, read through .shape and .dtype.
        b: second operand, read through .shape and .dtype.
        error_name: optional ErrorIf selector. DimensionMismatch grows one
            randomly chosen output dimension by 1.

    Returns:
        The value returned by ser.addOutput(shape, dtype).
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast
    b_rank = len(b.shape)
    out_shape = [
        b.shape[idx] if (dim == 1 and b_rank > idx) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # The index is drawn for every error_name, although only
    # DimensionMismatch uses it.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    out_dtype = DType.BOOL
    if error_name == ErrorIf.WrongOutputType:
        non_bool_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(non_bool_dtypes)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004338
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor of a reduction along one axis.

    The output shape is a copy of a.shape with the reduced axis set to 1,
    except for the axis-related error cases, which leave the copy
    untouched. For ShapeOfAxisNotOne, an axis that already has size 1 is
    replaced by a random size from rng.integers(2, 10).

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing integers() and choice().
        a: input operand, read through .shape and .dtype.
        axis: index of the axis being reduced.
        error_name: optional ErrorIf selector; WrongOutputType picks an
            output dtype different from a.dtype.

    Returns:
        The value returned by ser.addOutput(shape, dtype).
    """
    out_shape = a.shape.copy()

    keep_axis_errors = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in keep_axis_errors:
        out_shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        candidates = list(set(all_dtypes) - {a.dtype})
        out_dtype = rng.choice(candidates)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004367
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor of an argmax along one axis.

    The output shape is a copy of a.shape with `axis` removed (the axis is
    kept for AxisSmallerZero / AxisLargerRank). The output dtype is
    DType.INT32 unless error_name is ErrorIf.WrongOutputType.

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing integers() and choice().
        a: input operand, read through .shape.
        axis: index of the axis being reduced.
        error_name: optional ErrorIf selector. ArgmaxOutputRankMismatch
            drops the leading dimension or appends a size-1 dimension;
            ArgmaxOutputShapeMismatch enlarges every dimension by a random
            amount from rng.integers(1, 10).

    Returns:
        The value returned by ser.addOutput(shape, dtype).
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    out_dtype = DType.INT32
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        candidates = list(set(all_dtypes) - {DType.INT32})
        out_dtype = rng.choice(candidates)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004401
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor of a 2D convolution.

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing choice().
        ifm: input feature map, read through .shape and .dtype.
        filter: weights, read through .shape.
        accum_dtype: dtype used for the output in the non-error case.
        strides: two strides, indexed [0] for H and [1] for W.
        padding: four pad values; [0], [1] apply to H and [2], [3] to W.
        dilations: two dilations, indexed [0] for H and [1] for W.
        error_name: optional ErrorIf selector. ConvOutputShapeMismatch
            enlarges H and/or W by a multiple of the stride; WrongInputType
            forces DType.INT32; WrongOutputType picks a dtype via
            usableDTypes().

    Returns:
        The value returned by ser.addOutput(ofm_shape, out_dtype).
    """
    # IFM: NHWC
    # Filter: OHWI
    # OFM: NHWC

    def out_dim(in_size, pad_a, pad_b, kernel, dilation, stride):
        # Output extent of one spatial dimension.
        return (in_size - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1

    h = out_dim(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    w = out_dim(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in [1, 3]:
            h += rng.choice(choices) * strides[0]
        if change in [2, 3]:
            w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004453
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor of a 3D convolution.

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing choice().
        ifm: input feature map, read through .shape and .dtype.
        filter: weights, read through .shape.
        accum_dtype: dtype used for the output in the non-error case.
        strides: three strides, indexed [0] D, [1] H, [2] W.
        padding: six pad values; pairs [0:2], [2:4], [4:6] apply to D, H, W.
        dilations: three dilations, indexed [0] D, [1] H, [2] W.
        error_name: optional ErrorIf selector. ConvOutputShapeMismatch
            enlarges D, H and/or W by a multiple of the stride;
            WrongInputType forces DType.INT32; WrongOutputType picks a
            dtype via usableDTypes().

    Returns:
        The value returned by ser.addOutput(ofm_shape, out_dtype).
    """
    # IFM: NDHWC
    # Filter: ODHWI
    # OFM: NDHWC

    def out_dim(in_size, pad_a, pad_b, kernel, dilation, stride):
        # Output extent of one spatial dimension.
        return (in_size - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1

    d = out_dim(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    h = out_dim(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )
    w = out_dim(
        ifm.shape[3], padding[4], padding[5], filter.shape[3], dilations[2], strides[2]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in [1, 4]:
            d += rng.choice(choices) * strides[0]
        if change in [2, 4]:
            h += rng.choice(choices) * strides[1]
        if change in [3, 4]:
            w += rng.choice(choices) * strides[2]

    ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4515
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor of a depthwise 2D convolution.

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing choice().
        ifm: input feature map, read through .shape and .dtype.
        filter: weights, read through .shape.
        accum_dtype: dtype used for the output in the non-error case.
        strides: two strides, indexed [0] for H and [1] for W.
        padding: four pad values; [0], [1] apply to H and [2], [3] to W.
        dilations: two dilations, indexed [0] for H and [1] for W.
        error_name: optional ErrorIf selector. ConvOutputShapeMismatch
            enlarges H and/or W by a multiple of the stride; WrongInputType
            forces DType.INT32; WrongOutputType picks a dtype via
            usableDTypes().

    Returns:
        The value returned by ser.addOutput(ofm_shape, out_dtype).
    """
    # IFM: NHWC
    # Filter: HWCM
    # OFM: NHW C*M

    def out_dim(in_size, pad_a, pad_b, kernel, dilation, stride):
        # Output extent of one spatial dimension.
        return (in_size - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1

    h = out_dim(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    w = out_dim(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in [1, 3]:
            h += rng.choice(choices) * strides[0]
        if change in [2, 3]:
            w += rng.choice(choices) * strides[1]

    out_channels = filter.shape[2] * filter.shape[3]
    ofm_shape = [ifm.shape[0], h, w, out_channels]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004566
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor of a 2D pooling operator.

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing choice().
        ifm: input feature map, read through .shape and .dtype.
        kernel: two kernel sizes, indexed [0] for H and [1] for W.
        stride: two strides, indexed [0] for H and [1] for W.
        pad: four pad values; [0], [1] apply to H and [2], [3] to W.
        error_name: optional ErrorIf selector. PoolingOutputShapeMismatch
            enlarges H and/or W by a multiple of the stride;
            WrongOutputType picks an output dtype different from ifm.dtype.

    Returns:
        The value returned by ser.addOutput(ofm_shape, dtype).
    """
    # input: NHWC
    bad_args = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if bad_args:
        # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
        h = 1
        w = 1
    else:
        h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in [1, 3]:
            h += rng.choice(choices) * stride[0]
        if change in [2, 3]:
            w += rng.choice(choices) * stride[1]
    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    out_dtype = ifm.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        candidates = list(set(all_dtypes) - {ifm.dtype})
        out_dtype = rng.choice(candidates)

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004604
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor of a fully connected operator.

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: unused here.
        input: input operand, read through .shape ([N, IC]).
        filter: weights, read through .shape ([OC, IC]).
        accum_dtype: dtype used for the output.
        error_name: unused here.

    Returns:
        The value returned by ser.addOutput([N, OC], accum_dtype).
    """
    # input: N, IC
    # filter: OC, IC
    # output: N, OC
    batch = input.shape[0]
    out_channels = filter.shape[0]

    # Validated in arg_gen (also invalidated for ErrorIf)
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004617
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor of a batched matrix multiply.

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing choice().
        a: left operand, read through .shape ([N, H, C]) and .dtype.
        b: right operand, read through .shape ([N, C, W]).
        accum_dtype: dtype used for the output in the non-error case.
        error_name: optional ErrorIf selector. WrongOutputType picks an
            output dtype from a list keyed on a.dtype; WrongInputType
            forces DType.INT32.

    Returns:
        The value returned by ser.addOutput([N, H, W], out_dtype).

    Raises:
        ValueError: if error_name is WrongOutputType and a.dtype is not one
            of INT8, INT16, FP32, FP16 or BF16. Previously this case failed
            later with an UnboundLocalError.
    """
    # a: N, H, C
    # b: N, C, W
    # out: N, H, W
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Fail with a clear message instead of leaving incorrect_types
            # unbound for the rng.choice() call below.
            raise ValueError(
                "matmulOp: no wrong output dtypes defined for input dtype {}".format(
                    a.dtype
                )
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004665
@staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Add the output tensor of a concatenation along `axis`.

    The output shape starts as a copy of the first input's shape; unless
    the requested error makes that impossible, the sizes of the remaining
    inputs along `axis` are added to it.

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing integers() and choice().
        axis: index of the concatenation axis.
        *a: input tensors, each read through .shape (and .dtype for the
            first one).
        error_name: optional ErrorIf selector. ConcatShapeSumMismatch adds
            rng.integers(5, 10) to the axis size; WrongOutputType picks an
            output dtype different from the first input's dtype.

    Returns:
        The value returned by ser.addOutput(output_shape, dtype).
    """
    first, others = a[0], a[1:]

    # calculate the output shape, if possible, otherwise just use the first input shape
    output_shape = first.shape.copy()
    shape_not_computable = (
        # unable to concat tensors of different ranks
        ErrorIf.ConcatInputRankMismatch,
        # unable to concat tensors along an invalid axis
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if error_name not in shape_not_computable:
        for tensor in others:
            output_shape[axis] += tensor.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    out_dtype = first.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        candidates = list(all_dtypes - {first.dtype})
        out_dtype = rng.choice(candidates)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004701
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor of a pad operator.

    Each output dimension is the input dimension plus the two pad amounts
    given for it.

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing choice().
        a: input operand, read through .shape and .dtype.
        padding: per-dimension pairs, read as padding[i][0] and
            padding[i][1].
        error_name: optional ErrorIf selector. PadOutputShapeMismatch
            shrinks one random dimension by 1 or 2; RankMismatch replaces
            the shape via get_rank_mismatch_shape(); WrongOutputType picks
            an output dtype different from a.dtype.

    Returns:
        The value returned by ser.addOutput(output_shape, dtype).
    """
    output_shape = a.shape.copy()
    for idx, dim in enumerate(output_shape):
        output_shape[idx] = padding[idx][0] + padding[idx][1] + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(output_shape)))
        output_shape[bad_dim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        candidates = list(set(all_dtypes) - {a.dtype})
        out_dtype = rng.choice(candidates)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004732
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor of a dim operator.

    The output always has an empty shape list. Its dtype is DType.SHAPE
    unless error_name is ErrorIf.WrongOutputType, in which case one of the
    listed integer/float dtypes is chosen with rng.

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing choice().
        a: unused here.
        axis: unused here.
        error_name: optional ErrorIf selector.

    Returns:
        The value returned by ser.addOutput([], dtype).
    """
    out_dtype = DType.SHAPE
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        candidates = list(set(all_dtypes))
        out_dtype = rng.choice(candidates)

    return ser.addOutput([], out_dtype)
4753
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor of a reshape operator.

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing integers() and choice().
        a: input operand, read through .dtype.
        shape: requested output shape; it is copied, not modified.
        error_name: optional ErrorIf selector.
            TensorSizeInputOutputMismatch enlarges every dimension by a
            random amount from rng.integers(1, 10); WrongOutputType picks
            an output dtype different from a.dtype.

    Returns:
        The value returned by ser.addOutput(output_shape, dtype).
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(output_shape):
            output_shape[idx] = dim + rng.integers(1, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        candidates = list(set(all_dtypes) - {a.dtype})
        out_dtype = rng.choice(candidates)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004778
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the output tensor of a slice operator.

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing choice().
        input: input operand, read through .shape and .dtype.
        start: unused here.
        size: requested slice size; it is copied to form the output shape.
        error_name: optional ErrorIf selector. SizeOutputShapeMismatch
            perturbs every dimension by a random non-zero offset;
            InputSizeStartLengthMismatch uses a copy of input.shape;
            RankMismatch replaces the shape via get_rank_mismatch_shape();
            WrongOutputType picks an output dtype different from
            input.dtype.

    Returns:
        The value returned by ser.addOutput(output_shape, dtype).
    """
    out_dtype = input.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        candidates = list(set(all_dtypes) - {input.dtype})
        out_dtype = rng.choice(candidates)

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(output_shape):
            # Dimensions of 2 or less are only ever grown.
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004812
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the output tensor of a tile operator.

    Each output dimension is the input dimension times its multiple.

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing choice().
        a: input operand, read through .shape and .dtype.
        multiples: one multiplier per input dimension.
        error_name: optional ErrorIf selector. RankMismatch replaces the
            shape via get_rank_mismatch_shape(); WrongOutputType picks an
            output dtype different from a.dtype.

    Returns:
        The value returned by ser.addOutput(output_shape, dtype).
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for idx, dim in enumerate(a.shape):
        output_shape[idx] = dim * multiples[idx]

    if error_name == ErrorIf.RankMismatch:
        output_shape = get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        candidates = list(set(all_dtypes) - {a.dtype})
        out_dtype = rng.choice(candidates)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004841
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the output tensor of a transpose operator.

    Output dimension i takes the size of input dimension perms[i], except
    for IndexOutsideBounds / IndexUsedTwice, which keep a copy of the
    input shape.

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing integers() and choice().
        a: input operand, read through .shape and .dtype.
        perms: one source-dimension index per output dimension.
        error_name: optional ErrorIf selector.
            TensorSizeInputOutputMismatch enlarges every dimension by a
            random amount from rng.integers(1, 10); RankMismatch replaces
            the shape via get_rank_mismatch_shape(); WrongOutputType picks
            an output dtype different from a.dtype.

    Returns:
        The value returned by ser.addOutput(output_shape, dtype).
    """
    output_shape = a.shape.copy()

    assert len(perms) == len(output_shape)

    if error_name not in (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice):
        for idx in range(len(output_shape)):
            output_shape[idx] = a.shape[perms[idx]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx in range(len(output_shape)):
            output_shape[idx] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        output_shape = get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        candidates = list(set(all_dtypes) - {a.dtype})
        out_dtype = rng.choice(candidates)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004874
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the output tensor of a gather operator.

    The output shape is [values.shape[0], indices.shape[1],
    values.shape[2]].

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing choice().
        values: rank-3 operand, read through .shape and .dtype.
        indices: rank-2 operand, read through .shape.
        error_name: optional ErrorIf selector. WrongRank skips the rank
            and batch-size assertions; WrongOutputType picks an output
            dtype different from values.dtype.

    Returns:
        The value returned by ser.addOutput(output_shape, dtype).
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    out_dtype = values.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        candidates = list(set(all_dtypes) - {values.dtype})
        out_dtype = rng.choice(candidates)

    return ser.addOutput(output_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004900
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the output tensor of a scatter operator.

    The output uses values_in.shape directly (the same object, not a
    copy).

    Args:
        ser: serializer exposing addOutput(shape, dtype).
        rng: random generator exposing choice().
        values_in: rank-3 operand, read through .shape and .dtype.
        indices: rank-2 operand, read through .shape.
        input: rank-3 operand, read through .shape.
        error_name: optional ErrorIf selector. WrongRank skips the rank
            and dimension assertions; WrongOutputType picks an output
            dtype different from values_in.dtype.

    Returns:
        The value returned by ser.addOutput(values_in.shape, dtype).
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    output_shape = values_in.shape

    out_dtype = values_in.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        candidates = list(set(all_dtypes) - {values_in.dtype})
        out_dtype = rng.choice(candidates)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004929
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the output tensor for a table lookup to the serializer.

    The output has the same shape as the input. An INT16 input gives an
    INT32 output; any other input dtype gives INT8. When error_name is
    ErrorIf.WrongOutputType, rng picks a different dtype from a fixed
    candidate list instead.

    ser: object providing addOutput(shape, dtype).
    rng: random generator; only choice() is used.
    input: tensor-like with .shape and .dtype; asserted to be INT16 or
        INT8 unless error_name is ErrorIf.WrongInputType.
    error_name: optional ErrorIf selector.

    Returns the result of ser.addOutput().
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        result_dtype = DType.INT32
    else:
        result_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Keep candidate order; only the valid dtype is filtered out.
        invalid = [dtype for dtype in candidates if dtype != result_dtype]
        result_dtype = rng.choice(invalid)

    return ser.addOutput(input.shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004949
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the output tensor for a resize to the serializer.

    Output height and width are computed from input.shape[1] and
    input.shape[2] as
        ((in - 1) * scale_n - offset + border) // scale_d + 1
    and then adjusted depending on error_name.

    serializer: object providing addOutput(shape, dtype).
    rng: random generator; choice() and integers() are used.
    input: tensor-like with .shape.
    mode, input_dtype: accepted but not used in this function.
    scale: indexable as [scale_y_n, scale_y_d, scale_x_n, scale_x_d].
    offset, border: indexable; [0] is used for height, [1] for width.
    output_dtype: dtype passed through to addOutput().
    error_name: optional ErrorIf selector.

    Returns the result of serializer.addOutput().
    """

    def _output_extent(in_extent, numerator, denominator, axis):
        # Calculate OH / OW for one spatial axis (0 = y, 1 = x).
        scaled = (in_extent - 1) * numerator - offset[axis] + border[axis]
        return scaled // denominator + 1

    def _force_mismatch(extent, step):
        # Move in multiples of scale_y/x_d so we don't hit the non-integer
        # error case; step down when stepping up would reach the limit.
        if extent + step >= MAX_RESIZE_DIMENSION:
            extent -= step
            assert extent > 0  # Should have been caught in agResize
            return extent
        return extent + step

    scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
        scale[0],
        scale[1],
        scale[2],
        scale[3],
    )
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Raise every term to at least 1 before the size calculation
        # below, where the *_d terms are used as divisors.
        scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
            max(term, 1) for term in (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        )

    oh = _output_extent(input.shape[1], scale_y_n, scale_y_d, 0)
    ow = _output_extent(input.shape[2], scale_x_n, scale_x_d, 1)

    if error_name is not None:
        # Make sure the output tensor is valid, which can occur when
        # scale, offset or border have been changed for ERROR_IFs
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            upper = MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, upper), min(ow, upper)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        # 1 -> alter height only, 2 -> alter width only, 3 -> alter both
        selector = rng.choice([1, 2, 3])
        if selector in (1, 3):
            oh = _force_mismatch(oh, scale_y_d)
        if selector in (2, 3):
            ow = _force_mismatch(ow, scale_x_d)

    if error_name == ErrorIf.WrongRank:
        # The last dimension reuses input.shape[0] instead of reading
        # input.shape[3].
        # NOTE(review): presumably because a wrong-rank input may have no
        # fourth dimension - confirm.
        batch = input.shape[0]
        channels = input.shape[0]
    elif error_name == ErrorIf.BatchMismatch:
        batch = input.shape[0] + rng.integers(1, 10)
        channels = input.shape[3]
    elif error_name == ErrorIf.ChannelMismatch:
        batch = input.shape[0]
        channels = input.shape[3] + rng.integers(1, 10)
    else:
        batch = input.shape[0]
        channels = input.shape[3]

    return serializer.addOutput([batch, oh, ow, channels], output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005028
5029 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005030 def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005031 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005032
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the output tensor for a transpose conv2d to the serializer.

    ser: object providing addOutput(shape, dtype).
    rng: random generator; only choice() is used.
    ifm: tensor-like; only .dtype is read.
    output_shape: mutable sequence holding the output shape. For
        ErrorIf.ConvOutputShapeMismatch, entries 1 and/or 2 are
        increased IN PLACE, so the caller's object changes too.
    accum_dtype: output dtype used unless an error case overrides it.
    error_name: optional ErrorIf selector.

    Returns the result of ser.addOutput().
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        increments = [1, 2, 3]
        # 1 -> alter dim 1 only, 2 -> alter dim 2 only, 3 -> alter both
        selector = rng.choice(increments)
        for axis, triggers in ((1, (1, 3)), (2, (2, 3))):
            if selector in triggers:
                output_shape[axis] = output_shape[axis] + rng.choice(increments)

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005058
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add the two output tensors for a 2D FFT to the serializer.

    Both outputs share one shape and dtype. These start as a copy of
    ifm1.shape and as ifm1.dtype, then are altered for the
    WrongOutputType, BatchMismatch and FFTOutputShapeMismatch cases.

    serializer: object providing addOutput(shape, dtype).
    rng: random generator; choice() and integers() are used.
    ifm1, ifm2: tensor-likes with .shape and .dtype. Their dtypes are
        always asserted equal; their shapes are asserted equal unless
        error_name is ErrorIf.FFTInputShapeMismatch.
    error_name: optional ErrorIf selector.

    Returns a list of the two serializer.addOutput() results.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    # Work on a copy so the adjustments below leave ifm1.shape untouched.
    out_shape = ifm1.shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis = rng.choice([1, 2])
        out_shape[axis] += rng.integers(1, 10)

    # Two outputs with identical shape and dtype.
    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
5089
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add the two output tensors for a real-input 2D FFT.

    The output shape is value.shape with the last dimension replaced by
    last // 2 + 1. The output dtype is value.dtype. Both are altered for
    the WrongOutputType, BatchMismatch and FFTOutputShapeMismatch cases.

    serializer: object providing addOutput(shape, dtype).
    rng: random generator; choice() and integers() are used.
    value: tensor-like with .shape and .dtype; asserted to be rank 3
        unless error_name is ErrorIf.WrongRank.
    error_name: optional ErrorIf selector.

    Returns a list of the two serializer.addOutput() results.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    # Leading dimensions are kept; only the last one is reduced.
    out_shape = list(in_shape[:-1])
    out_shape.append(in_shape[-1] // 2 + 1)

    out_dtype = value.dtype
    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis = rng.choice([1, 2])
        out_shape[axis] += rng.integers(1, 10)

    first = serializer.addOutput(out_shape, out_dtype)
    second = serializer.addOutput(out_shape, out_dtype)
    return [first, second]