# Copyright (c) 2020-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
import math
import os
from copy import deepcopy

import numpy as np
import serializer.tosa_serializer as ts
from generator.tosa_error_if import ErrorIf
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode

# DTypeNames, DType, Op and ResizeMode are convenience variables to the
# flatc-generated types that should be enums, but aren't


def valueToName(item, value):
    """Get the name of an attribute with the given value.

    This convenience function is needed to print meaningful names for
    the values of the tosa.Op.Op and tosa.DType.DType classes.
    This would not be necessary if they were subclasses of Enum, or
    IntEnum, which, sadly, they are not.

    Args:
        item: The class, or object, to find the value in
        value: The value to find

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8) # returns 'INT8'
        name = valueToName(DType, 4) # returns 'INT8'

    Returns:
        The name of the first attribute found with a matching value.

    Raises:
        ValueError if the value is not found
    """
    for attr in dir(item):
        if getattr(item, attr) == value:
            return attr
    raise ValueError(f"value ({value}) not found")


def allDTypes(*, excludes=None):
    """Get a set of all DType values, optionally excluding some values.

    This convenience function is needed to provide a sequence of DType values.
    This would be much easier if DType was a subclass of Enum, or IntEnum,
    as we could then iterate over the values directly, instead of using
    dir() to find the attributes and then check if they are what we want.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    excludes = () if not excludes else excludes
    return {
        getattr(DType, t)
        for t in dir(DType)
        if not callable(getattr(DType, t))
        and not t.startswith("__")
        and getattr(DType, t) not in excludes
    }


def usableDTypes(*, excludes=None):
    """Get a set of usable DType values, optionally excluding some values.

    Excludes (DType.UNKNOWN, DType.UINT8) in addition to the excludes
    specified by the caller, as the serializer lib does not support them.
    If you wish to include 'UNKNOWN' or 'UINT8' use allDTypes instead.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    omit = {DType.UNKNOWN, DType.UINT8}
    omit.update(excludes if excludes else ())
    return allDTypes(excludes=omit)


def product(shape):
    value = 1
    for n in shape:
        value *= n
    return value


class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype, error_name=None):

        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            zero_point = testGen.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype, error_name),
                TosaQuantGen.getQinfo(testGen, dtype),
            )
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype),
                TosaQuantGen.getQinfo(testGen, dtype, error_name),
            )
        else:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype),
                TosaQuantGen.getQinfo(testGen, dtype),
            )
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], error_name)
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], error_name)
        else:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])

        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.MatMulQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype, error_name),
                TosaQuantGen.getQinfo(testGen, dtype, error_name),
            )
        else:
            qinfo.MatMulQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype),
                TosaQuantGen.getQinfo(testGen, dtype),
            )
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype, error_name=None):
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, error_name))
        else:
            qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
        #     scaleFp, scaleBits, m, multiplier, shift))

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift
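
    # Worked example for computeMultiplierAndShift (illustrative, not executed):
    # with scaleFp = 0.75 and scale32 = True, math.frexp(0.75) gives (0.75, 0), so
    #   multiplier = round(0.75 * (1 << 31)) = 1610612736
    #   shift      = -0 + 31 = 31
    # and 1610612736 * 2.0**-31 == 0.75, i.e. the pair encodes the requested
    # scale as a fixed-point multiplier and right shift.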


class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                if error_name == ErrorIf.DimensionMismatch:
                    shape_bcast[fuzz_idx] += 1
                elif error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]
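
    # Example of the shapes produced by tgConv2D (illustrative values): with an
    # IFM of [1, 32, 32, 8] (NHWC), a "filter" attribute of [3, 3] and a randomly
    # chosen ofm_depth of 16, the returned list is
    #   [[1, 32, 32, 8], [16, 3, 3, 8], [16]]
    # i.e. NHWC input, OHWI weights and a per-output-channel bias.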

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
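
    # Illustrative trace of tgConcatConstInput (not executed): for four input
    # shapes of [2, 8, 3] and axis=1, halving the axis on each step gives
    #   [[2, 8, 3], [2, 4, 3], [2, 2, 3], [2, 2, 3]]
    # so the extra const inputs still concatenate back to length 8 on axis 1.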


class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for
    operators that take attributes or other parameters.

    The return value is a list of (descriptive_name, [arglist]) tuples where
    the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function.
    """

    def __init__(self):
        pass

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype, error_name=None):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype, error_name=None):
        """Build the axis argument for operators that take a single axis"""
        axes = []
        shape = shapeList[0]

        if error_name == ErrorIf.AxisSmallerZero:
            small_axis = testGen.rng.integers(-5, 0)
            axes.append(("axis{}".format(small_axis), [small_axis]))
        elif error_name == ErrorIf.AxisLargerRank:
            large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
            axes.append(("axis{}".format(large_axis), [large_axis]))
        else:
            for a in range(0, len(shape)):
                axes.append(("axis{}".format(a), [a]))

        return axes

    @staticmethod
    def agConv(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]
        # determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3])
        k = [int(x) for x in opName.split("_")[-1].split("x")]

        # Check the rank
        rank = 5 if opName.startswith("conv3d") else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits batch and channels
        k_rank = rank - 2
        assert len(k) == k_rank

        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        if error_name == ErrorIf.PadSmallerZero:
            p_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * k_rank))}
        if error_name == ErrorIf.DilationSmallerOne:
            d_vals = [testGen.rng.choice(range(-5, 1))]
        else:
            d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

        if not error_name:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    {x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))}
                )
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
            bigDilation = 7
            dilations.update(
                {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
            )

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 100
        sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # padding must not exceed the kernel size ?
                        # and p[0] < k[0] and p[1] < k[0]
                        # and p[2] < k[1] and p[3] < k[1]
                        # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
                        # the padded shape must exceed the kernel size
                        and (ifm_shape[1] + p[0] + p[1]) > k[0]
                        and (ifm_shape[2] + p[2] + p[3]) > k[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
                        # the padded shape must exceed the dilation
                        and (ifm_shape[1] + p[0] + p[1]) > d[0]
                        and (ifm_shape[2] + p[2] + p[3]) > d[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
                    ):
                        arg_list.append(
                            (
                                "st{}_pad{}_dilat{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in p]),
                                    "".join([str(x) for x in d]),
                                ),
                                [s, p, d],
                            )
                        )
                    n += 1

        return arg_list
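
    # Rough sizing example for the sparsity mechanism above (illustrative
    # numbers): with, say, 1600 stride/padding/dilation combinations, a positive
    # test computes sparsity = 1600 // 100 + 1 = 17, which is already not a
    # multiple of 2, 3 or 5, so roughly one in seventeen combinations (about 94)
    # becomes a test case.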

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == 4
            assert len(filter_shape) == 4

        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        if error_name == ErrorIf.PadSmallerZero:
            p_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 2))}
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        if error_name == ErrorIf.DilationSmallerOne:
            d_vals = [testGen.rng.choice(range(-5, 1))]
        else:
            d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * 2))}

        if not error_name:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    {x for x in itertools.product(*([[0, bigPadding]] * 2))}
                )
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
            bigDilation = 7
            dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 100
        sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if n % sparsity == 0:
                        # Determine the output shape
                        oh = (
                            ifm_shape[1]
                            - filter_shape[1]
                            - (filter_shape[1] - 1) * (d[0] - 1)
                            + 2 * p[0]
                        ) // s[0] + 1
                        ow = (
                            ifm_shape[2]
                            - filter_shape[2]
                            - (filter_shape[2] - 1) * (d[1] - 1)
                            + 2 * p[1]
                        ) // s[1] + 1
                        os = [ifm_shape[0], oh, ow, filter_shape[0]]
                        arg_list.append(
                            (
                                "st{}_pad{}_dilat{}_os{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in p]),
                                    "".join([str(x) for x in d]),
                                    "x".join([str(x) for x in os]),
                                ),
                                [s, p, d, os],
                            )
                        )
                    n += 1

        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype == DType.FLOAT:
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            return []

        for paddings in shape_pad_values:
            name = "pad"
            for r in range(rank):
                before, after = paddings[r]
                name = f"{name}{before}{after}"
            arg_list.append((name, [np.array(paddings), pad_const_int, pad_const_fp]))

        return arg_list
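
    # Example of the argument names generated by agPad (illustrative): for a
    # rank 2 input with the default pad_min/pad_max of 0/1, one entry could be
    #   ("pad0110", [np.array(((0, 1), (1, 0))), pad_const_int, pad_const_fp])
    # i.e. one before/after pair of digits per dimension.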

    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        shape = shapeList[0]
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 2)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        # add some oversize argument values
        bigStride = 7
        strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
        bigKernel = 6
        kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
        if max(shape) < 64:
            # padding must be less than the kernel size
            bigPadding = bigKernel - 1
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 500
        sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for k in sorted(list(kernels)):
                    if error_name in [
                        ErrorIf.StrideSmallerOne,
                        ErrorIf.KernelSmallerOne,
                        ErrorIf.PadSmallerZero,
                        ErrorIf.PadLargerEqualKernel,
                    ]:
                        sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
                            testGen, error_name, s, p, k
                        )
                        if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                            arg_list.append(
                                (
                                    "st{}_kern{}_pad{}".format(
                                        "".join([str(x) for x in sNew]),
                                        "".join([str(x) for x in kNew]),
                                        "".join([str(x) for x in pNew]),
                                    ),
                                    [sNew, pNew, kNew],
                                )
                            )
                    elif (
                        n % sparsity == 0
                        # padding must not exceed the kernel size
                        and p[0] < k[0]
                        and p[1] < k[0]
                        and p[2] < k[1]
                        and p[3] < k[1]
                        # the padded shape must exceed the kernel size
                        and (shape[1] + p[0] + p[1]) > k[0]
                        and (shape[2] + p[2] + p[3]) > k[1]
                    ):
                        arg_list.append(
                            (
                                "st{}_kern{}_pad{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in k]),
                                    "".join([str(x) for x in p]),
                                ),
                                [s, p, k],
                            )
                        )
                    n += 1

        return arg_list

    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        if error_name == ErrorIf.WrongOutputType:
            dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
        elif inDtype == DType.INT8:
            dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT16:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT32:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FLOAT:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output type for incorrect input type
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))

        return arg_list

    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
            if (
                dtype in [DType.UINT8, DType.INT8]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                continue
            if (
                inDtype == DType.UINT8
                and dtype != DType.INT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT8 is INT8, skip all other combinations
                continue
            if (
                inDtype != DType.INT8
                and dtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT8 is INT8, skip all other combinations
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype)
            ):
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition. Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[dtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [dtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list
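
    # Example of the test names generated by agRescale (illustrative, assuming
    # DTypeNames[DType.INT16] is "INT16"): an INT16 output with scale32=True,
    # double_round=True and per_channel=False is named "outINT16_sc1_dr1_pc0".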

    @staticmethod
    def agMul(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        if dtype is DType.INT32:
            for p in range(testGen.args.num_rand_permutations):

                shift = testGen.randInt(0, 32)

                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("perm0_shift0", [0]))

        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    # Helper function for reshape. Gets some factors of a larger number.
    @staticmethod
    def getFactors(val, start=1):
        factors = []

        for i in range(start, int(np.sqrt(val)) + 1):
            if (val % i) == 0:
                factors.append(i)

        return factors
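
    # Illustrative example (not executed): getFactors(12) checks 1..3, since
    # int(np.sqrt(12)) == 3, and returns [1, 2, 3]; co-factors above sqrt(val)
    # are not included.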
1092
1093 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001094 def agReshape(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001095 arg_list = []
1096
1097 origShape = shapeList[0]
1098
1099 totalElements = 1
1100 for s in origShape:
1101 totalElements *= s
1102
1103 # This code is NOT fast. Fortunately, the numbers are fairly small.
1104 factors = TosaArgGen.getFactors(totalElements)
1105
1106 for p in range(testGen.args.num_rand_permutations):
Matthew Haddon5fc4e682021-07-07 11:28:29 +01001107 newRank = testGen.randInt(1, 7)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001108 if len(factors) < newRank:
Eric Kunzee5e26762020-10-13 16:11:07 -07001109 continue
1110
Matthew Haddon2ad047d2021-06-22 16:55:23 +01001111 found = True
1112 # escape_counter breaks while loop if it continues on for too long
1113 escape_counter = 0
1114 while found:
1115 newShape = []
1116 # Generate newShape ensuring it isn't a duplicate
1117 remainingElements = totalElements
1118 shuffledFactors = testGen.rng.permutation(factors)
Matthew Haddon5fc4e682021-07-07 11:28:29 +01001119 for i in range(1, newRank):
Matthew Haddon2ad047d2021-06-22 16:55:23 +01001120 # pick rank-1 factors
1121 newShape.append(shuffledFactors[0])
1122 remainingElements = remainingElements // shuffledFactors[0]
1123 shuffledFactors = testGen.rng.permutation(
1124 TosaArgGen.getFactors(remainingElements)
1125 )
1126 newShape.append(remainingElements)
Eric Kunzee5e26762020-10-13 16:11:07 -07001127
Matthew Haddon2ad047d2021-06-22 16:55:23 +01001128 # Toss in a -1 sometimes
1129 minusOne = testGen.randInt(0, newRank * 4)
1130 if minusOne < newRank:
1131 newShape[minusOne] = -1
Eric Kunzee5e26762020-10-13 16:11:07 -07001132
Matthew Haddon2ad047d2021-06-22 16:55:23 +01001133 # Check for duplicates
1134 found = False
1135 for name, other_shape in arg_list:
1136 if other_shape[0] == newShape:
1137 found = True
1138 break
1139
1140 escape_counter += 1
1141 if escape_counter >= 100:
1142 break
1143
1144 if not found:
1145 arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001146
1147 return arg_list
1148
Eric Kunzee5e26762020-10-13 16:11:07 -07001149 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001150 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001151 arg_list = []
1152
1153 ifm_shape = shapeList[0]
1154
Matthew Haddone807aae2021-10-11 18:12:58 +01001155 if error_name == ErrorIf.IndexOutsideBounds:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001156 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
Matthew Haddone807aae2021-10-11 18:12:58 +01001157 incorrect_small_index = range(-len(ifm_shape), 0)
1158 permutations = [p for p in itertools.permutations(incorrect_large_index)]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001159 permutations.extend(
1160 [p for p in itertools.permutations(incorrect_small_index)]
1161 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001162 elif error_name == ErrorIf.IndexUsedTwice:
1163 # Create list with a duplicated index
1164 perm_range = list(range(len(ifm_shape)))
1165 index_choice = testGen.rng.choice(range(len(perm_range)))
1166 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
1167 permutations = [p for p in itertools.permutations(perm_range)]
1168
Matthew Haddone807aae2021-10-11 18:12:58 +01001169 else:
1170 # Get all permutations
1171 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
Eric Kunzee5e26762020-10-13 16:11:07 -07001172
Jeremy Johnsona6185572021-06-21 15:55:35 +01001173 # Limit to possible permutations from shape dimension or argument setting
1174 limit = min(len(permutations), testGen.args.num_rand_permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -07001175
Jeremy Johnsona6185572021-06-21 15:55:35 +01001176 # Get random permutation generator that uses all permutations
1177 random_permutations = testGen.rng.permutation(permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -07001178
Jeremy Johnsona6185572021-06-21 15:55:35 +01001179 # Create list of required amount of permutations
Kevin Chengacb550f2021-06-29 15:32:19 -07001180 arg_list = [
1181 ("perm{}".format(p), [random_permutations[p].tolist()])
1182 for p in range(limit)
1183 ]
Eric Kunzee5e26762020-10-13 16:11:07 -07001184 return arg_list
1185
1186 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001187 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001188 arg_list = []
1189
1190 ifm_shape = shapeList[0]
1191 rank = len(ifm_shape)
1192
1193 for p in range(testGen.args.num_rand_permutations):
Matthew Haddone807aae2021-10-11 18:12:58 +01001194 start = []
Eric Kunzee5e26762020-10-13 16:11:07 -07001195 size = []
1196
Kevin Cheng550ccc52021-03-03 11:21:43 -08001197 valid = True
Eric Kunzee5e26762020-10-13 16:11:07 -07001198
1199 for i in range(rank):
1200 if ifm_shape[i] > 1:
Matthew Haddone807aae2021-10-11 18:12:58 +01001201 start.append(testGen.randInt(0, ifm_shape[i]))
1202 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001203
1204 # Invalid slice size?
1205 if size[i] == 0:
1206 valid = False
1207 else:
Matthew Haddone807aae2021-10-11 18:12:58 +01001208 start.append(0)
Eric Kunzee5e26762020-10-13 16:11:07 -07001209 size.append(1)
1210
1211 if valid:
Matthew Haddone807aae2021-10-11 18:12:58 +01001212 # If ERROR_IF test required then incorrect start, size will be returned
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001213 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
1214 testGen, error_name, ifm_shape, start, size
1215 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001216 arg_list.append(("perm{}".format(p), [start, size]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001217 return arg_list
1218
1219 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001220 def agTile(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001221 arg_list = []
1222
1223 ifm_shape = shapeList[0]
1224 rank = len(ifm_shape)
1225
1226 for p in range(testGen.args.num_rand_permutations):
1227
            # Pick a few random but small multiplier values, because
            # otherwise this has a tendency to generate enormous tensors
1231 multiples = []
1232 for i in range(rank):
Matthew Haddon82ad4d62021-08-20 15:02:39 +01001233 if ifm_shape[i] > 1000:
                    # Use a multiple of 1 when this dimension is already
                    # large, to keep the tensor size down
Matthew Haddon82ad4d62021-08-20 15:02:39 +01001236 multiples.append(1)
1237 elif max(ifm_shape) > 1000:
1238 multiples.append(2)
1239 else:
1240 multiples.append(testGen.randInt(1, 4))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001241 arg_list.append(("perm{}".format(p), [multiples]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001242
1243 return arg_list
1244
1245 @staticmethod
Matthew Haddone86fd342021-09-07 16:12:21 +01001246 def agResize(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001247 arg_list = []
1248
1249 ifm_shape = shapeList[0]
Matthew Haddon848efb42021-09-09 12:30:53 +01001250 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Eric Kunzee5e26762020-10-13 16:11:07 -07001251
1252 # Exclude illegal {mode, type} configurations. Pick legal output types
Matthew Haddon848efb42021-09-09 12:30:53 +01001253 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
Les Bell33d837e2021-08-10 08:34:43 +01001254 outputDTypeList = [DType.INT8]
Matthew Haddon848efb42021-09-09 12:30:53 +01001255 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001256 outputDTypeList = [DType.INT16]
Matthew Haddon848efb42021-09-09 12:30:53 +01001257 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
Les Bell33d837e2021-08-10 08:34:43 +01001258 outputDTypeList = [DType.INT32]
Matthew Haddon848efb42021-09-09 12:30:53 +01001259 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001260 outputDTypeList = [DType.INT48]
Kevin Cheng77d0f762020-11-24 10:26:32 -08001261 elif dtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001262 outputDTypeList = [DType.FLOAT]
Matthew Haddon848efb42021-09-09 12:30:53 +01001263 elif error_name == ErrorIf.WrongInputType:
1264 # If an incorrect input type is used then we set a 'correct'
1265 # output type to avoid other errors
1266 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -07001267 else:
1268 continue
1269
1270 for outputDType in outputDTypeList:
1271 for perm in range(testGen.args.num_rand_permutations):
Eric Kunzee5e26762020-10-13 16:11:07 -07001272 # Randomly generate legal output dimensions and shift
1273 # and then compute the stride and offset based on them
                    # An output_dim of 1 would cause the offset to exceed the
                    # allowed range, so a minimum value of 2 is produced below
1276 output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001277 while (float(ifm_shape[1]) / float(output_dims[0])) >= 16:
Matthew Haddone86fd342021-09-07 16:12:21 +01001278 output_dims[0] += 1
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001279 while (float(ifm_shape[2]) / float(output_dims[1])) >= 16:
Matthew Haddone86fd342021-09-07 16:12:21 +01001280 output_dims[1] += 1
1281
Kevin Cheng77d0f762020-11-24 10:26:32 -08001282 in_center_h = (ifm_shape[1] - 1) / 2.0
1283 in_center_w = (ifm_shape[2] - 1) / 2.0
1284 out_center_h = (output_dims[0] - 1) / 2.0
1285 out_center_w = (output_dims[1] - 1) / 2.0
Eric Kunzee5e26762020-10-13 16:11:07 -07001286
Kevin Cheng77d0f762020-11-24 10:26:32 -08001287 fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
1288 fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
1289 fp_offset_y = in_center_h - fp_stride_y * out_center_h
1290 fp_offset_x = in_center_w - fp_stride_x * out_center_w
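                    # Worked example (illustrative, not from the original
                    # source): for ifm H=8 and output H=4, fp_stride_y = 8/4
                    # = 2.0 and fp_offset_y = (8-1)/2 - 2.0 * (4-1)/2
                    # = 3.5 - 3.0 = 0.5, i.e. the input and output sample
                    # grids are centre-aligned.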
Eric Kunzee5e26762020-10-13 16:11:07 -07001291
Kevin Cheng77d0f762020-11-24 10:26:32 -08001292 if outputDType == DType.FLOAT:
Jeremy Johnson27cf5432021-11-16 11:12:17 +00001293 float_op = True
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001294 arg_str = (
1295 "mode{}_shift{}_odim{}x{}_out{}"
1296 "_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}"
1297 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001298 shift = 0
1299 stride = [0, 0]
1300 offset = [0, 0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08001301 stride_fp = [fp_stride_y, fp_stride_x]
1302 offset_fp = [fp_offset_y, fp_offset_x]
Matthew Haddone86fd342021-09-07 16:12:21 +01001303
Kevin Cheng77d0f762020-11-24 10:26:32 -08001304 else:
Jeremy Johnson27cf5432021-11-16 11:12:17 +00001305 float_op = False
1306 arg_str = "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}"
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001307 shift = testGen.randInt(1, 12)
Jeremy Johnsonc0b24f02021-10-28 17:12:42 +01001308 # Now search for a shift value (1 to 11) that will produce
1309 # a valid and predictable resize operation
1310 count = 0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001311 while count < 12:
Kevin Cheng77d0f762020-11-24 10:26:32 -08001312 unit = float(1 << shift)
1313 stride_y = int(round(fp_stride_y * unit))
1314 stride_x = int(round(fp_stride_x * unit))
1315 offset_y = int(round(fp_offset_y * unit))
1316 offset_x = int(round(fp_offset_x * unit))
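                            # The float stride/offset are converted to fixed
                            # point with 'shift' fractional bits, e.g. with
                            # shift=10 a stride of 2.0 becomes
                            # 2.0 * 1024 = 2048 (illustrative values only).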
Eric Kunzee5e26762020-10-13 16:11:07 -07001317
Jeremy Johnsonc0b24f02021-10-28 17:12:42 +01001318 if (
Jeremy Johnson27cf5432021-11-16 11:12:17 +00001319 stride_y <= 0
1320 or stride_x <= 0
1321 or stride_y >= (16 << shift)
Jeremy Johnsonc0b24f02021-10-28 17:12:42 +01001322 or stride_x >= (16 << shift)
1323 or offset_y >= (16 << shift)
1324 or offset_x >= (16 << shift)
1325 or offset_y <= (-16 << shift)
1326 or offset_x <= (-16 << shift)
1327 ):
1328 # Change the shift value and check again
1329 count += 1
1330 shift = (shift % 11) + 1
1331 continue
1332
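                            # This helper mirrors the resize index
                            # arithmetic: it walks the output positions and
                            # returns the clamped neighbouring input indices;
                            # the current shift value is rejected below if
                            # ia0 > ia1 ever occurs.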
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001333 def RESIZE_REQUIRE_CALC(
1334 length_in, length_out, stride, offset, shift
1335 ):
Jeremy Johnsonc0b24f02021-10-28 17:12:42 +01001336 # Perform the pseudo loop to look for out of bounds
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001337 for pos in range(0, length_out):
Jeremy Johnsonc0b24f02021-10-28 17:12:42 +01001338 a = pos * stride + offset
1339 ia = a >> shift
1340 ia0 = max(ia, 0)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001341 ia1 = min(ia + 1, length_in - 1)
Jeremy Johnsonc0b24f02021-10-28 17:12:42 +01001342 if ia0 > ia1:
1343 # Found a problem value
1344 break
1345 return ia0, ia1
1346
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001347 iy0, iy1 = RESIZE_REQUIRE_CALC(
1348 ifm_shape[1], output_dims[0], stride_y, offset_y, shift
1349 )
1350 ix0, ix1 = RESIZE_REQUIRE_CALC(
1351 ifm_shape[2], output_dims[1], stride_x, offset_x, shift
1352 )
Jeremy Johnsonc0b24f02021-10-28 17:12:42 +01001353 if ix0 > ix1 or iy0 > iy1:
1354 # Change the shift value and check again
1355 count += 1
1356 shift = (shift % 11) + 1
1357 continue
1358 break
1359
1360 if count >= 12:
1361 # Couldn't find a good set of values for this test, skip it
1362 continue
1363
Kevin Cheng550ccc52021-03-03 11:21:43 -08001364 stride = [stride_y, stride_x]
1365 offset = [offset_y, offset_x]
Kevin Cheng77d0f762020-11-24 10:26:32 -08001366
1367 stride_fp = [0.0, 0.0]
1368 offset_fp = [0.0, 0.0]
1369
Jeremy Johnson27cf5432021-11-16 11:12:17 +00001370 # Common for all data types
1371 if error_name is not None:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001372 (
1373 shift,
1374 stride,
1375 stride_fp,
1376 offset,
1377 offset_fp,
1378 outputDTypeNew,
1379 ) = TosaErrorIfArgGen.eiResizeErrorIf(
Jeremy Johnson27cf5432021-11-16 11:12:17 +00001380 testGen,
1381 error_name,
1382 mode,
1383 dtype,
1384 shapeList,
1385 outputDType,
1386 shift,
1387 stride,
1388 stride_fp,
1389 offset,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001390 offset_fp,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001391 )
Jeremy Johnson27cf5432021-11-16 11:12:17 +00001392 else:
1393 outputDTypeNew = outputDType
1394
1395 arg_list.append(
1396 (
1397 arg_str.format(
1398 "N" if mode == ResizeMode.NEAREST else "B",
1399 shift,
1400 output_dims[0],
1401 output_dims[1],
1402 testGen.typeStr(outputDTypeNew),
1403 stride_fp[0] if float_op else stride[0],
1404 stride_fp[1] if float_op else stride[1],
1405 offset_fp[0] if float_op else offset[0],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001406 offset_fp[1] if float_op else offset[1],
Jeremy Johnson27cf5432021-11-16 11:12:17 +00001407 ),
1408 [
1409 mode,
1410 stride,
1411 offset,
1412 shift,
1413 stride_fp,
1414 offset_fp,
1415 output_dims,
1416 dtype,
1417 outputDTypeNew,
1418 ],
1419 )
1420 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001421
1422 return arg_list
1423
Kevin Chengfe392ce2021-10-18 21:51:55 +00001424 @staticmethod
1425 def agTable(testGen, opName, shapeList, dtype, error_name=None):
1426 arg_list = []
1427
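        # A TOSA TABLE lookup table is assumed to be 256 entries for INT8
        # input and 513 entries for INT16 input, matching the sizes
        # generated below.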
1428 if dtype == DType.INT8:
1429 table = np.int32(
1430 testGen.rng.integers(low=-128, high=128, size=[256])
1431 ).tolist()
1432 else: # INT16
1433 table = np.int32(
1434 testGen.rng.integers(low=-32768, high=32768, size=[513])
1435 ).tolist()
1436
1437 arg_list.append(
1438 (
1439 "",
1440 [table],
1441 )
1442 )
1443 return arg_list
1444
    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001446 # CondIf generates the condition values here.
1447 # Convert to tensors in the build function, along with the
1448 # then and else blocks
1449 arg_list = []
1450
1451 for c in [False, True]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001452 arg_list.append(("cond{}".format(int(c)), [c]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001453
1454 return arg_list
1455
    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001457 # While loop: 0 iterations, 1, more than 1
1458 arg_list = []
1459
1460 for iter in [0, 1, 4]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001461 arg_list.append(("iter{}".format(iter), [iter]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001462
1463 return arg_list
1464
Matthew Haddone86fd342021-09-07 16:12:21 +01001465
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001466class TosaErrorIfArgGen:
Matthew Haddone86fd342021-09-07 16:12:21 +01001467 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001468 def eiResizeErrorIf(
1469 testGen,
1470 error_name,
1471 mode,
1472 dtype,
1473 shapeList,
1474 outputDType,
1475 shift,
1476 stride,
1477 stride_fp,
1478 offset,
1479 offset_fp,
1480 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01001481
1482 if outputDType == DType.FLOAT:
1483 if error_name == ErrorIf.StrideSmallerEqualZero:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001484 stride_fp = testGen.rng.random(size=[2]) - 2
Matthew Haddone86fd342021-09-07 16:12:21 +01001485 elif error_name == ErrorIf.ShiftNotZero:
1486 shift = testGen.rng.integers(1, 5)
1487 elif error_name == ErrorIf.StrideLargerDimension:
1488 shape = shapeList[0]
1489 transform_height = testGen.rng.choice([False, True])
1490 if transform_height:
1491 stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
1492 else:
1493 stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
1494 else:
1495 if error_name == ErrorIf.StrideSmallerEqualZero:
1496 stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
1497 elif error_name == ErrorIf.ShiftSmallerOne:
1498 shift = testGen.rng.integers(-3, 1)
1499 if shift <= 0:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001500 stride = [
1501 (16 >> -shift) - 1,
1502 (16 >> -shift) - 1,
1503 ] # avoids other ERROR_IF checks
1504 offset = [
1505 (16 >> -shift) - 1,
1506 (16 >> -shift) - 1,
1507 ] # avoids other ERROR_IF checks
Matthew Haddone86fd342021-09-07 16:12:21 +01001508 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001509 stride = [
1510 (16 << shift) - 1,
1511 (16 << shift) - 1,
1512 ] # avoids other ERROR_IF checks
1513 offset = [
1514 (16 << shift) - 1,
1515 (16 << shift) - 1,
1516 ] # avoids other ERROR_IF checks
Matthew Haddone86fd342021-09-07 16:12:21 +01001517 elif error_name == ErrorIf.ShiftLargerEleven:
1518 shift = np.int16(testGen.rng.integers(12, 15))
1519 elif error_name == ErrorIf.StrideLargerDimension:
1520 shape = shapeList[0]
1521 transform_height = testGen.rng.choice([False, True])
1522 if transform_height:
1523 stride[0] = shape[1] + testGen.rng.integers(1, 10)
1524 else:
1525 stride[1] = shape[2] + testGen.rng.integers(1, 10)
1526 elif error_name == ErrorIf.StrideLargerEqualMax:
1527 stride = [(16 << shift) + 1, (16 << shift) + 1]
1528 elif error_name == ErrorIf.OffsetLargerEqualMax:
1529 offset = [(16 << shift) + 1, (16 << shift) + 1]
1530 elif error_name == ErrorIf.OffsetSmallerEqualMin:
1531 offset = [(-16 << shift) - 1, (-16 << shift) - 1]
1532
Matthew Haddon848efb42021-09-09 12:30:53 +01001533 if error_name == ErrorIf.WrongOutputType:
1534 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001535 incorrect_types = (
1536 DType.INT4,
1537 DType.INT16,
1538 DType.INT32,
1539 DType.INT48,
1540 DType.FLOAT,
1541 )
Matthew Haddon848efb42021-09-09 12:30:53 +01001542 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001543 incorrect_types = (
1544 DType.INT4,
1545 DType.INT8,
1546 DType.INT32,
1547 DType.INT48,
1548 DType.FLOAT,
1549 )
Matthew Haddon848efb42021-09-09 12:30:53 +01001550 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001551 incorrect_types = (
1552 DType.INT4,
1553 DType.INT8,
1554 DType.INT16,
1555 DType.INT48,
1556 DType.FLOAT,
1557 )
Matthew Haddon848efb42021-09-09 12:30:53 +01001558 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001559 incorrect_types = (
1560 DType.INT4,
1561 DType.INT8,
1562 DType.INT16,
1563 DType.INT32,
1564 DType.FLOAT,
1565 )
Matthew Haddon848efb42021-09-09 12:30:53 +01001566 elif dtype == DType.FLOAT:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001567 incorrect_types = (
1568 DType.INT4,
1569 DType.INT8,
1570 DType.INT16,
1571 DType.INT32,
1572 DType.INT48,
1573 )
Matthew Haddon848efb42021-09-09 12:30:53 +01001574 outputDType = testGen.rng.choice(a=incorrect_types)
Matthew Haddone86fd342021-09-07 16:12:21 +01001575
Matthew Haddon848efb42021-09-09 12:30:53 +01001576 return shift, stride, stride_fp, offset, offset_fp, outputDType
1577
1578 @staticmethod
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001579 def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
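        # Returns a (stride, pad, kernel) triple with one element replaced by
        # deliberately invalid values for the requested error, or
        # (None, None, None) when the error does not apply. Illustrative
        # example: for kernel (2, 2) and ErrorIf.PadLargerEqualKernel each
        # pad value is drawn from {2, 3, 4}, guaranteeing pad >= kernel.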
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001580 if (
1581 error_name == ErrorIf.StrideSmallerOne
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001582 # padding must not exceed the kernel size
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001583 and pad[0] < kernel[0]
1584 and pad[1] < kernel[0]
1585 and pad[2] < kernel[1]
1586 and pad[3] < kernel[1]
1587 ):
1588 wrongStride = (
1589 testGen.rng.choice([0, -1, -2, -3]),
1590 testGen.rng.choice([0, -1, -2, -3]),
1591 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001592 return wrongStride, pad, kernel
1593 elif error_name == ErrorIf.PadSmallerZero:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001594 wrongPad = (
1595 testGen.rng.choice([-1, -2, -3]),
1596 testGen.rng.choice([-1, -2, -3]),
1597 testGen.rng.choice([-1, -2, -3]),
1598 testGen.rng.choice([-1, -2, -3]),
1599 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001600 return stride, wrongPad, kernel
1601 elif error_name == ErrorIf.KernelSmallerOne:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001602 wrongKernel = (
1603 testGen.rng.choice([0, -1, -2, -3]),
1604 testGen.rng.choice([0, -1, -2, -3]),
1605 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001606 return stride, pad, wrongKernel
1607 elif error_name == ErrorIf.PadLargerEqualKernel:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001608 wrongPad = (
1609 testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
1610 testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
1611 testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
1612 testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
1613 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001614 return stride, wrongPad, kernel
1615 else:
1616 return None, None, None
1617
Matthew Haddonc2025212021-10-08 21:21:05 +01001618 @staticmethod
1619 def eiRescaleWrongOutputType(input_dtype, output_dtype):
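        # Returns True when the (input, output) dtype pair is invalid for
        # RESCALE, e.g. eiRescaleWrongOutputType(DType.INT8, DType.INT48)
        # is True while (DType.INT16, DType.INT32) is False (illustrative).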
1620 if input_dtype == DType.INT8:
1621 if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
1622 return True
1623 if input_dtype in [DType.INT16, DType.INT32]:
1624 if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
1625 return True
1626 elif input_dtype == DType.INT48:
1627 if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
1628 return True
1629 elif input_dtype == DType.UINT8:
1630 if output_dtype != DType.INT8:
1631 return True
1632 return False
1633
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001634 @staticmethod
Matthew Haddon848efb42021-09-09 12:30:53 +01001635 def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
1636 # Mess up input/output tensors for ERROR_IF checks
1637 if error_name == "WrongInputList":
1638 add_input = testGen.rng.choice([True, False])
1639 if add_input:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001640 input_list.append("eiDummyInput")
Matthew Haddon848efb42021-09-09 12:30:53 +01001641 else:
1642 input_list = input_list[:-1]
Les Bell0e027d42021-11-09 14:42:14 +00001643 elif error_name == "WrongOutputList":
Matthew Haddon848efb42021-09-09 12:30:53 +01001644 add_output = testGen.rng.choice([True, False])
1645 if add_output:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001646 output_list.append("eiDummyOutput")
Matthew Haddon848efb42021-09-09 12:30:53 +01001647 else:
1648 output_list = []
1649 return input_list, output_list
Matthew Haddone86fd342021-09-07 16:12:21 +01001650
Matthew Haddonc2025212021-10-08 21:21:05 +01001651 @staticmethod
Matthew Haddon630c17c2021-10-14 15:05:41 +01001652 def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001653 """Restrict the dimensions and overall size of a shape to
1654 max_dim and max_items.
1655 """
Matthew Haddon630c17c2021-10-14 15:05:41 +01001656 new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
1657 while product(new_shape) > max_items:
1658 new_shape = [max(d - 1, 1) for d in new_shape]
1659 return new_shape
Matthew Haddone807aae2021-10-11 18:12:58 +01001660
    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
1662 if error_name == ErrorIf.StartSmallerZero:
1663 newStart = []
1664 for i in range(len(input_shape)):
1665 newStart.append(testGen.rng.choice([-3, -2, -1]))
1666 return newStart, size
1667 elif error_name == ErrorIf.SizeSmallerEqualZero:
1668 newSize = []
1669 for i in range(len(input_shape)):
1670 newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
1671 return start, newSize
1672 elif error_name == ErrorIf.StartSizeOutsideBounds:
1673 newStart, newSize = [], []
1674 for i in range(len(input_shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001675 newStart.append(input_shape[i] - 1)
Matthew Haddone807aae2021-10-11 18:12:58 +01001676 newSize.append(testGen.rng.choice([2, 3, 4]))
1677 return newStart, newSize
1678 elif error_name == ErrorIf.InputSizeStartLengthMismatch:
1679 remove = testGen.rng.choice([True, False])
1680 if remove:
1681 newStart = start[1:]
1682 newSize = size[1:]
1683 else:
1684 newStart = start
1685 newStart.append(1)
1686 newSize = size
1687 newSize.append(1)
1688 return newStart, newSize
1689 else:
1690 return start, size
1691
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001692 @staticmethod
1693 def eiCastErrorIf(testGen, input_dtype):
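        # Returns the output dtypes that should trigger ERROR_IF for a CAST
        # from input_dtype, e.g. an INT8 input yields [DType.INT48].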
1694 if input_dtype in [DType.BOOL, DType.FLOAT]:
1695 outputDType = [DType.BOOL, DType.INT48, DType.FLOAT]
1696 elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
1697 outputDType = [DType.INT48]
1698 else:
            assert False, f"input_dtype ({input_dtype}) not supported"
1700 return outputDType
1701
1702
Matthew Haddone86fd342021-09-07 16:12:21 +01001703class TosaErrorValidator:
Matthew Haddon848efb42021-09-09 12:30:53 +01001704 @staticmethod
1705 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
Les Bell729b0352021-11-24 10:28:21 +00001706 """Check ERROR_IF statements are caught and set the expected result.
1707
1708 Args:
1709 serializer: the serializer to set the expected result in
1710 validator_fcns: a sequence of validator functions to verify the result
1711 error_name: the name of the ERROR_IF condition to check for
1712 kwargs: keyword arguments for the validator functions
1713 Returns:
1714 True if the result matches the expected result; otherwise False
1715 """
1716 overall_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001717 for val_fcn in validator_fcns:
1718 val_result = val_fcn(True, **kwargs)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001719 validator_name = val_result["error_name"]
1720 error_result = val_result["error_result"]
1721 error_reason = val_result["error_reason"]
Matthew Haddon848efb42021-09-09 12:30:53 +01001722
Les Bell0e027d42021-11-09 14:42:14 +00001723 # expect an error IFF the error_name and validator_name match
1724 expected_result = error_result == (error_name == validator_name)
Les Bell729b0352021-11-24 10:28:21 +00001725 overall_result &= expected_result
Les Bell0e027d42021-11-09 14:42:14 +00001726
1727 if expected_result and error_result:
Jeremy Johnson2ec34942021-12-14 16:34:05 +00001728 serializer.setExpectedReturnCode(2, True, desc=error_reason)
Les Bell0e027d42021-11-09 14:42:14 +00001729 elif error_result: # and not expected_result
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001730 print(
1731 f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
1732 f" Expected: {error_name}, Got: {validator_name}"
1733 )
1734 elif not expected_result: # and not error_result
1735 print(
1736 f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
1737 f" Expected: {error_name}"
1738 )
Les Bell0e027d42021-11-09 14:42:14 +00001739
1740 if not expected_result:
1741 for k, v in sorted(kwargs.items()):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001742 if k != "op":
1743 if k.endswith("dtype"):
Les Bell0e027d42021-11-09 14:42:14 +00001744 v = valueToName(DType, v)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001745 print(f" {k} = {v}")
Matthew Haddon848efb42021-09-09 12:30:53 +01001746
Les Bell729b0352021-11-24 10:28:21 +00001747 return overall_result
1748
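    # Each ev* validator below returns an info dict with "error_name",
    # "error_result", "error_reason" and "param_reqs" keys. With check=False
    # only the parameter requirements are reported; with check=True the
    # kwargs describing the generated test are inspected and "error_result"
    # records whether the ERROR_IF condition actually occurred.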
Matthew Haddon848efb42021-09-09 12:30:53 +01001749 @staticmethod
1750 def evWrongInputType(check=False, **kwargs):
Les Bell0e027d42021-11-09 14:42:14 +00001751 error_result = False
Matthew Haddon848efb42021-09-09 12:30:53 +01001752
1753 # Find the unsupported input data types
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001754 op = kwargs["op"]
1755 input_dtypes = op["types"]
1756 allowed_input_dtypes = {
1757 t[0] if isinstance(t, list) else t for t in input_dtypes
1758 }
Les Bell0e027d42021-11-09 14:42:14 +00001759 wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
Matthew Haddon848efb42021-09-09 12:30:53 +01001760
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001761 if op["op"] == Op.CLAMP:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001762 wrong_input_dtypes.remove(DType.INT48)
1763
Matthew Haddon848efb42021-09-09 12:30:53 +01001764 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001765 input_dtype = kwargs["input_dtype"]
Les Bell0e027d42021-11-09 14:42:14 +00001766 if input_dtype not in allowed_input_dtypes:
Matthew Haddon848efb42021-09-09 12:30:53 +01001767 error_result = True
1768
1769 info_dict = {
Les Bell0e027d42021-11-09 14:42:14 +00001770 "error_name": ErrorIf.WrongInputType,
Matthew Haddon848efb42021-09-09 12:30:53 +01001771 "error_result": error_result,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001772 "error_reason": "Input data type not supported for this operator",
1773 "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
Matthew Haddon848efb42021-09-09 12:30:53 +01001774 }
1775 return info_dict
1776
1777 @staticmethod
1778 def evWrongOutputType(check=False, **kwargs):
Matthew Haddon848efb42021-09-09 12:30:53 +01001779 error_result = False
Matthew Haddon848efb42021-09-09 12:30:53 +01001780
1781 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001782 input_dtype = kwargs["input_dtype"]
1783 output_dtype = kwargs["output_dtype"]
1784 op = kwargs["op"]
Matthew Haddon848efb42021-09-09 12:30:53 +01001785
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001786 if op["op"] == Op.RESIZE:
1787 mode = kwargs["mode"]
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001788 if (
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001789 (
1790 mode == ResizeMode.NEAREST
1791 and input_dtype == DType.INT8
1792 and output_dtype != DType.INT8
1793 )
1794 or (
1795 mode == ResizeMode.NEAREST
1796 and input_dtype == DType.INT16
1797 and output_dtype != DType.INT16
1798 )
1799 or (
1800 mode == ResizeMode.BILINEAR
1801 and input_dtype == DType.INT8
1802 and output_dtype != DType.INT32
1803 )
1804 or (
1805 mode == ResizeMode.BILINEAR
1806 and input_dtype == DType.INT16
1807 and output_dtype != DType.INT48
1808 )
1809 or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001810 ):
1811 error_result = True
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001812
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001813 elif op["op"] == Op.RESCALE:
Matthew Haddonc2025212021-10-08 21:21:05 +01001814 if input_dtype == DType.INT8:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001815 if output_dtype not in [
1816 DType.UINT8,
1817 DType.INT8,
1818 DType.INT16,
1819 DType.INT32,
1820 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01001821 error_result = True
1822 if input_dtype in [DType.INT16, DType.INT32]:
1823 if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
1824 error_result = True
1825 elif input_dtype == DType.INT48:
1826 if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
1827 error_result = True
1828 elif input_dtype == DType.UINT8:
1829 if output_dtype != DType.INT8:
1830 error_result = True
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001831
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001832 elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001833 if (
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001834 (input_dtype == DType.INT8 and output_dtype != DType.INT32)
1835 or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
1836 or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001837 ):
1838 error_result = True
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001839
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001840 elif op["op"] == Op.ARGMAX:
1841 if (
1842 input_dtype in [DType.INT8, DType.INT16, DType.FLOAT]
1843 and output_dtype != DType.INT32
1844 ):
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001845 error_result = True
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001846
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001847 elif op["op"] == Op.MUL:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001848 if input_dtype != DType.FLOAT and output_dtype != DType.INT32:
1849 error_result = True
1850 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
1851 error_result = True
1852
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001853 elif op["op"] == Op.TABLE:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001854 if input_dtype == DType.INT8 and output_dtype != DType.INT8:
1855 error_result = True
1856 elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
1857 error_result = True
1858
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001859 elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001860 if output_dtype != DType.BOOL:
1861 error_result = True
1862
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001863 elif op["op"] == Op.CAST:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001864 if (
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001865 (
1866 input_dtype == DType.BOOL
1867 and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
1868 )
1869 or (
1870 input_dtype == DType.INT8
1871 and output_dtype
1872 not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
1873 )
1874 or (
1875 input_dtype == DType.INT16
1876 and output_dtype
1877 not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
1878 )
1879 or (
1880 input_dtype == DType.INT32
1881 and output_dtype
1882 not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
1883 )
1884 or (
1885 input_dtype == DType.FLOAT
1886 and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
1887 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001888 ):
1889 error_result = True
1890
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001891 elif op["op"] in {
1892 Op.CONV2D,
1893 Op.CONV3D,
1894 Op.DEPTHWISE_CONV2D,
1895 Op.TRANSPOSE_CONV2D,
1896 }:
Les Bell0e027d42021-11-09 14:42:14 +00001897 if (
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001898 input_dtype == DType.INT8
1899 and output_dtype != DType.INT32
1900 or input_dtype == DType.INT16
1901 and output_dtype != DType.INT48
1902 or input_dtype == DType.FLOAT
1903 and output_dtype != DType.FLOAT
Les Bell0e027d42021-11-09 14:42:14 +00001904 ):
1905 error_result = True
1906 # invalid input types are ignored, to avoid reporting multiple errors
1907
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001908 else:
1909 if output_dtype != input_dtype:
1910 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001911
1912 info_dict = {
Les Bell0e027d42021-11-09 14:42:14 +00001913 "error_name": ErrorIf.WrongOutputType,
Matthew Haddon848efb42021-09-09 12:30:53 +01001914 "error_result": error_result,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001915 "error_reason": (
1916 "Output data type not supported for this configuration of operator"
1917 ),
1918 "param_reqs": {"rank": None, "dtype": None, "shape": None},
Matthew Haddon848efb42021-09-09 12:30:53 +01001919 }
1920 return info_dict
1921
1922 @staticmethod
1923 def evWrongRank(check=False, **kwargs):
1924 all_ranks = (1, 2, 3, 4, 5)
1925
1926 # Make a list of incorrect ranks
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001927 assert "op" in kwargs
1928 op = kwargs["op"]
1929 rmin, rmax = op["rank"]
Matthew Haddon848efb42021-09-09 12:30:53 +01001930 rank_range = range(rmin, rmax + 1)
1931 incorrect_ranks = list(set(all_ranks) - set(rank_range))
Matthew Haddonc2025212021-10-08 21:21:05 +01001932 # Remove small incorrect ranks to avoid index errors
1933 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
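        # For example (illustrative), an op with rank range (1, 4) leaves
        # incorrect_ranks == [5] at this point.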
Matthew Haddon848efb42021-09-09 12:30:53 +01001934 # Set minimum incorrect rank to 3 to avoid index error
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001935 if op["op"] in [Op.RESIZE]:
Matthew Haddon848efb42021-09-09 12:30:53 +01001936 incorrect_ranks = [3, 5]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001937 elif op["op"] in [Op.TRANSPOSE]:
Matthew Haddon01c359d2021-10-15 16:30:48 +01001938 incorrect_ranks = [7, 8]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001939 elif op["op"] in [Op.CONV3D]:
Les Bell0e027d42021-11-09 14:42:14 +00001940 incorrect_ranks = [6, 7]
Matthew Haddon848efb42021-09-09 12:30:53 +01001941
1942 error_name = ErrorIf.WrongRank
1943 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
1944 error_result = False
1945 error_reason = "Rank not supported for this operator"
1946
1947 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001948 input_shape = kwargs["input_shape"]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001949
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001950 if (
1951 op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
1952 and len(input_shape) != 4
1953 ):
Matthew Haddon848efb42021-09-09 12:30:53 +01001954 error_result = True
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001955 elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001956 error_result = True
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001957 elif op["op"] == Op.MATMUL and len(input_shape) != 3:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001958 error_result = True
Matthew Haddonc2025212021-10-08 21:21:05 +01001959 else:
1960 if len(input_shape) not in rank_range:
1961 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001962
1963 info_dict = {
1964 "error_name": error_name,
1965 "error_result": error_result,
1966 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001967 "param_reqs": param_reqs,
Matthew Haddon848efb42021-09-09 12:30:53 +01001968 }
1969 return info_dict
1970
1971 @staticmethod
1972 def evWrongInputList(check=False, **kwargs):
1973 error_name = ErrorIf.WrongInputList
1974 param_reqs = {"rank": None, "dtype": None, "shape": None}
1975 error_result = False
1976 error_reason = "Op input list does not match expected input"
1977
1978 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001979 op = kwargs["op"]
1980 input_list = kwargs["input_list"]
1981 num_operands = kwargs["num_operands"]
1982 if op["op"] in [Op.SCATTER, Op.GATHER]:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001983 # SCATTER/GATHER add an indices input tensor in their build functions
1984 num_operands += 1
Kevin Chengfe392ce2021-10-18 21:51:55 +00001985 if len(input_list) != num_operands:
1986 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001987
1988 info_dict = {
1989 "error_name": error_name,
1990 "error_result": error_result,
1991 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001992 "param_reqs": param_reqs,
Matthew Haddon848efb42021-09-09 12:30:53 +01001993 }
1994 return info_dict
1995
1996 @staticmethod
1997 def evWrongOutputList(check=False, **kwargs):
1998 error_name = ErrorIf.WrongOutputList
1999 param_reqs = {"rank": None, "dtype": None, "shape": None}
2000 error_result = False
2001 error_reason = "Op output list does not match expected output"
2002
2003 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002004 output_list = kwargs["output_list"]
Matthew Haddon848efb42021-09-09 12:30:53 +01002005 # Note this will be incorrect if an operator returns more than one output
2006 if len(output_list) != 1:
2007 error_result = True
2008
2009 info_dict = {
2010 "error_name": error_name,
2011 "error_result": error_result,
2012 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002013 "param_reqs": param_reqs,
Matthew Haddon848efb42021-09-09 12:30:53 +01002014 }
2015 return info_dict
Matthew Haddone86fd342021-09-07 16:12:21 +01002016
2017 @staticmethod
2018 def evMaxDimExceeded(check=False, **kwargs):
2019 error_name = ErrorIf.MaxDimExceeded
Matthew Haddon848efb42021-09-09 12:30:53 +01002020 param_reqs = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002021 "rank": [4, 4],
Matthew Haddon848efb42021-09-09 12:30:53 +01002022 "dtype": [DType.INT8],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002023 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
2024 }
Matthew Haddone86fd342021-09-07 16:12:21 +01002025 error_result = False
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002026 error_reason = (
2027 "At least one maximum dimension is greater than or equal to 16384"
2028 )
Matthew Haddone86fd342021-09-07 16:12:21 +01002029
2030 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002031 input_shape = kwargs["input_shape"]
2032 output_shape = kwargs["output_shape"] # Note this is just (OH, OW)
2033 if (
2034 (input_shape[1] >= 16384)
2035 or (input_shape[2] >= 16384)
2036 or (output_shape[0] >= 16384)
2037 or (output_shape[1] >= 16384)
2038 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01002039 error_result = True
2040
2041 info_dict = {
2042 "error_name": error_name,
2043 "error_result": error_result,
2044 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002045 "param_reqs": param_reqs,
Matthew Haddone86fd342021-09-07 16:12:21 +01002046 }
2047 return info_dict
2048
2049 @staticmethod
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002050 def evBatchMismatch(check=False, **kwargs):
2051 error_name = ErrorIf.BatchMismatch
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002052 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002053 error_result = False
2054 error_reason = "Input batch size not equal to output batch size"
2055
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002056 assert "op" in kwargs
2057 op = kwargs["op"]
2058 rmin, rmax = op["rank"]
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002059 rank_range = range(rmin, rmax + 1)
2060
2061 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002062 input_shape = kwargs["input_shape"]
2063 output_shape = kwargs[
2064 "result_tensor"
2065 ].shape # Note this is just (N, OH, OW, C)
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002066
2067 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
2068 error_result = True
2069
2070 info_dict = {
2071 "error_name": error_name,
2072 "error_result": error_result,
2073 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002074 "param_reqs": param_reqs,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002075 }
2076 return info_dict
2077
2078 @staticmethod
2079 def evChannelMismatch(check=False, **kwargs):
2080 error_name = ErrorIf.ChannelMismatch
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002081 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002082 error_result = False
2083 error_reason = "Input channel size not equal to output channel size"
2084
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002085 assert "op" in kwargs
2086 op = kwargs["op"]
2087 rmin, rmax = op["rank"]
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002088 rank_range = range(rmin, rmax + 1)
2089
2090 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002091 input_shape = kwargs["input_shape"]
2092 output_shape = kwargs[
2093 "result_tensor"
2094 ].shape # Note this is just (N, OH, OW, C)
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002095 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
2096 error_result = True
2097
2098 info_dict = {
2099 "error_name": error_name,
2100 "error_result": error_result,
2101 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002102 "param_reqs": param_reqs,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002103 }
2104 return info_dict
2105
2106 @staticmethod
Matthew Haddone86fd342021-09-07 16:12:21 +01002107 def evStrideSmallerEqualZero(check=False, **kwargs):
2108 error_name = ErrorIf.StrideSmallerEqualZero
2109 param_reqs = {"rank": None, "dtype": None, "shape": None}
2110 error_result = False
2111 error_reason = "Stride value smaller than or equal zero"
2112
2113 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002114 input_dtype = kwargs["input_dtype"]
2115 output_dtype = kwargs["output_dtype"]
Matthew Haddon848efb42021-09-09 12:30:53 +01002116 if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002117 stride = kwargs["stride"] # Work around wrong input/output type tests
Matthew Haddon848efb42021-09-09 12:30:53 +01002118 elif output_dtype == DType.FLOAT:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002119 stride = kwargs["stride_fp"]
Matthew Haddon848efb42021-09-09 12:30:53 +01002120 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002121 stride = kwargs[
2122 "stride_fp"
2123 ] # Work around wrong input/output type tests
Matthew Haddone86fd342021-09-07 16:12:21 +01002124 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002125 stride = kwargs["stride"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002126
2127 if min(stride) <= 0:
2128 error_result = True
2129
2130 info_dict = {
2131 "error_name": error_name,
2132 "error_result": error_result,
2133 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002134 "param_reqs": param_reqs,
Matthew Haddone86fd342021-09-07 16:12:21 +01002135 }
2136 return info_dict
2137
2138 @staticmethod
2139 def evStrideLargerEqualMax(check=False, **kwargs):
2140 error_name = ErrorIf.StrideLargerEqualMax
2141 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
2142 error_result = False
2143 error_reason = "Stride value larger than or equal to maximum value"
2144
2145 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002146 shift = kwargs["shift"]
2147 input_dtype = kwargs["input_dtype"]
2148 stride = kwargs["stride"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002149 if input_dtype in [DType.INT8, DType.INT16]:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002150 if shift >= 0 and (
2151 stride[0] >= (16 << shift) or stride[1] >= (16 << shift)
2152 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01002153 error_result = True
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002154 elif shift < 0 and (
2155 stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)
2156 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01002157 error_result = True
2158
2159 info_dict = {
2160 "error_name": error_name,
2161 "error_result": error_result,
2162 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002163 "param_reqs": param_reqs,
Matthew Haddone86fd342021-09-07 16:12:21 +01002164 }
2165 return info_dict
2166
Matthew Haddone86fd342021-09-07 16:12:21 +01002167 @staticmethod
2168 def evStrideLargerDimension(check=False, **kwargs):
2169 error_name = ErrorIf.StrideLargerDimension
2170 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
2171 error_result = False
2172 error_reason = "Stride value larger than or equal to H/W dimension"
2173
2174 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002175 shape = kwargs["input_shape"]
2176 input_dtype = kwargs["input_dtype"]
2177 stride = kwargs["stride_fp"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002178
            if input_dtype == DType.FLOAT and (
                stride[0] > shape[1] or stride[1] > shape[2]
            ):
Matthew Haddone86fd342021-09-07 16:12:21 +01002184 error_result = True
2185
2186 info_dict = {
2187 "error_name": error_name,
2188 "error_result": error_result,
2189 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002190 "param_reqs": param_reqs,
Matthew Haddone86fd342021-09-07 16:12:21 +01002191 }
2192 return info_dict
2193
Matthew Haddone86fd342021-09-07 16:12:21 +01002194 @staticmethod
2195 def evOffsetSmallerEqualMin(check=False, **kwargs):
2196 error_name = ErrorIf.OffsetSmallerEqualMin
2197 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
2198 error_result = False
2199 error_reason = "Offset value smaller than or equal to minimum value"
2200
2201 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002202 shift = kwargs["shift"]
2203 output_dtype = kwargs["output_dtype"]
Matthew Haddon848efb42021-09-09 12:30:53 +01002204 if output_dtype == DType.FLOAT:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002205 offset = kwargs["offset_fp"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002206 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002207 offset = kwargs["offset"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002208
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002209 if shift >= 0 and (
2210 offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)
2211 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01002212 error_result = True
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002213 elif shift < 0 and (
2214 offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)
2215 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01002216 error_result = True
2217
2218 info_dict = {
2219 "error_name": error_name,
2220 "error_result": error_result,
2221 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002222 "param_reqs": param_reqs,
Matthew Haddone86fd342021-09-07 16:12:21 +01002223 }
2224 return info_dict
2225
2226 @staticmethod
2227 def evOffsetLargerEqualMax(check=False, **kwargs):
2228 error_name = ErrorIf.OffsetLargerEqualMax
2229 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
2230 error_result = False
2231 error_reason = "Offset value larger than or equal to maximum value"
2232
2233 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002234 shift = kwargs["shift"]
2235 output_dtype = kwargs["output_dtype"]
Matthew Haddon848efb42021-09-09 12:30:53 +01002236 if output_dtype == DType.FLOAT:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002237 offset = kwargs["offset_fp"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002238 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002239 offset = kwargs["offset"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002240
2241 if shift >= 0:
2242 if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
2243 error_result = True
2244
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002245 if shift >= 0 and (
2246 offset[0] >= (16 << shift) or offset[1] >= (16 << shift)
2247 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01002248 error_result = True
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002249 elif shift < 0 and (
2250 offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)
2251 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01002252 error_result = True
2253
2254 info_dict = {
2255 "error_name": error_name,
2256 "error_result": error_result,
2257 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002258 "param_reqs": param_reqs,
Matthew Haddone86fd342021-09-07 16:12:21 +01002259 }
2260 return info_dict
2261
2262 @staticmethod
2263 def evShiftNotZero(check=False, **kwargs):
2264 error_name = ErrorIf.ShiftNotZero
2265 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
2266 error_result = False
2267 error_reason = "Shift value must be zero for float input"
2268
2269 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002270 shift = kwargs["shift"]
2271 input_dtype = kwargs["input_dtype"]
2272 output_dtype = kwargs["output_dtype"]
2273 if (
2274 input_dtype == DType.FLOAT
2275 and output_dtype == DType.FLOAT
2276 and shift != 0
2277 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01002278 error_result = True
2279
2280 info_dict = {
2281 "error_name": error_name,
2282 "error_result": error_result,
2283 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002284 "param_reqs": param_reqs,
Matthew Haddone86fd342021-09-07 16:12:21 +01002285 }
2286 return info_dict
2287
Matthew Haddone86fd342021-09-07 16:12:21 +01002288 @staticmethod
2289 def evShiftSmallerOne(check=False, **kwargs):
2290 error_name = ErrorIf.ShiftSmallerOne
2291 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
2292 error_result = False
2293 error_reason = "Shift value smaller than one"
2294
2295 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002296 shift = kwargs["shift"]
2297 input_dtype = kwargs["input_dtype"]
2298 output_dtype = kwargs["output_dtype"]
Matthew Haddon848efb42021-09-09 12:30:53 +01002299 if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01002300 error_result = True
2301
2302 info_dict = {
2303 "error_name": error_name,
2304 "error_result": error_result,
2305 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002306 "param_reqs": param_reqs,
Matthew Haddone86fd342021-09-07 16:12:21 +01002307 }
2308 return info_dict
2309
2310 @staticmethod
2311 def evShiftLargerEleven(check=False, **kwargs):
2312 error_name = ErrorIf.ShiftLargerEleven
2313 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
2314 error_result = False
2315 error_reason = "Shift value larger than eleven"
2316
2317 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002318 shift = kwargs["shift"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002319 if shift > 11:
2320 error_result = True
2321
2322 info_dict = {
2323 "error_name": error_name,
2324 "error_result": error_result,
2325 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002326 "param_reqs": param_reqs,
Matthew Haddone86fd342021-09-07 16:12:21 +01002327 }
2328 return info_dict
2329
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002330 @staticmethod
2331 def evRankMismatch(check=False, **kwargs):
2332 error_name = ErrorIf.RankMismatch
2333 param_reqs = {"rank": None, "dtype": None, "shape": None}
2334 error_result = False
2335 error_reason = "Input Rank does not match output rank"
2336
2337 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002338 input1_shape = kwargs["input1"].shape
2339 input2_shape = kwargs["input2"].shape
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002340 # In case of SELECT op
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002341 input3_shape = (
2342 kwargs["input3"].shape if "input3" in kwargs else input2_shape
2343 )
2344 output_shape = kwargs["result_tensor"].shape
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002345 if (
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002346 (len(input1_shape) != len(output_shape))
2347 or (len(input2_shape) != len(output_shape))
2348 or (len(input3_shape) != len(output_shape))
2349 ):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002350 error_result = True
2351
2352 info_dict = {
2353 "error_name": error_name,
2354 "error_result": error_result,
2355 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002356 "param_reqs": param_reqs,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002357 }
2358 return info_dict
2359
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002360 @staticmethod
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002361 def evDimensionMismatch(check=False, **kwargs):
2362 error_name = ErrorIf.DimensionMismatch
2363 param_reqs = {"rank": None, "dtype": None, "shape": None}
2364 error_result = False
2365 error_reason = "Input Dimensions do not match output"
2366
2367 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002368 input1_shape = kwargs["input1"].shape
2369 input2_shape = kwargs["input2"].shape
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002370 # In case of SELECT op
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002371 input3_shape = (
2372 kwargs["input3"].shape if "input3" in kwargs else input2_shape
2373 )
2374 output_shape = kwargs["result_tensor"].shape
2375 for i in range(
2376 min(len(input1_shape), len(input2_shape), len(input3_shape))
2377 ):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002378 if (
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002379 (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
2380 or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
2381 or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
2382 ):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002383 error_result = True
2384
2385 info_dict = {
2386 "error_name": error_name,
2387 "error_result": error_result,
2388 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002389 "param_reqs": param_reqs,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002390 }
2391 return info_dict
2392
2393 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002394 def evInputZeroPointNotZero(check=False, **kwargs):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002395 op = kwargs["op"]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002396 error_result = False
Les Bell0e027d42021-11-09 14:42:14 +00002397
2398 # Quantizable types
2399 qTypes = (DType.INT8, DType.UINT8)
2400
2401 # This does not apply to quantizable types
2402 inputDtypes = [
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002403 dtype
2404 for dtype in op["types"]
2405 if (isinstance(dtype, list) and dtype[0] not in qTypes)
2406 or (not isinstance(dtype, list) and dtype not in qTypes)
Les Bell0e027d42021-11-09 14:42:14 +00002407 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002408
2409 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002410 input_dtype = kwargs["input_dtype"]
2411 if isinstance(kwargs["qinfo"], tuple):
2412 qinfo = kwargs["qinfo"]
Matthew Haddonc2025212021-10-08 21:21:05 +01002413 input_zero_point = qinfo[0]
2414 else:
2415 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002416 qinfo = kwargs["qinfo"].ints
Matthew Haddonc2025212021-10-08 21:21:05 +01002417 input_zero_point = qinfo[0][1]
2418
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002419 if op["op"] == Op.MATMUL:
2420 qinfo = kwargs["qinfo"].ints
Les Bell0e027d42021-11-09 14:42:14 +00002421 for dtype, zp in (
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002422 (kwargs["input_dtype"], qinfo[0][1]),
2423 (kwargs["input2_dtype"], qinfo[1][1]),
Les Bell0e027d42021-11-09 14:42:14 +00002424 ):
2425 if dtype not in qTypes and zp != 0:
2426 error_result = True
2427 break
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002428 else:
Les Bell0e027d42021-11-09 14:42:14 +00002429 error_result = input_dtype not in qTypes and input_zero_point != 0
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002430
2431 info_dict = {
Les Bell0e027d42021-11-09 14:42:14 +00002432 "error_name": ErrorIf.InputZeroPointNotZero,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002433 "error_result": error_result,
Les Bell0e027d42021-11-09 14:42:14 +00002434 "error_reason": "Input DType not INT8 and zero point not 0",
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002435 "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002436 }
2437 return info_dict
2438
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002439 @staticmethod
2440 def evWeightZeroPointNotZero(check=False, **kwargs):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002441 op = kwargs["op"]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002442
2443 # exclude inputs with INT8 weights
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002444 inputDtypes = [
2445 t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
2446 ]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002447
2448 error_name = ErrorIf.WeightZeroPointNotZero
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002449 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002450 error_result = False
2451 error_reason = "Weight DType not INT8 and zero point not 0"
2452
2453 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002454 weight_dtype = kwargs["weight_dtype"]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002455 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002456 qinfo = kwargs["qinfo"].ints
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002457 weight_zero_point = qinfo[1][1]
2458 if weight_dtype != DType.INT8 and weight_zero_point != 0:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002459 error_result = True
2460
2461 info_dict = {
2462 "error_name": error_name,
2463 "error_result": error_result,
2464 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002465 "param_reqs": param_reqs,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002466 }
2467 return info_dict
2468
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002469 @staticmethod
2470 def evOutputZeroPointNotZero(check=False, **kwargs):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002471 op = kwargs["op"]
2472 inputDtypes = op["types"].copy()
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002473 if DType.INT8 in inputDtypes:
2474 inputDtypes.remove(DType.INT8)
2475 if DType.UINT8 in inputDtypes:
2476 inputDtypes.remove(DType.UINT8)
2477
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002478 error_name = ErrorIf.OutputZeroPointNotZero
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002479 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002480 error_result = False
2481 error_reason = "Output DType not INT8 and zero point not 0"
2482
2483 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002484 input_dtype = kwargs["input_dtype"]
2485 output_dtype = kwargs["output_dtype"]
2486 if isinstance(kwargs["qinfo"], tuple):
2487 qinfo = kwargs["qinfo"]
Matthew Haddonc2025212021-10-08 21:21:05 +01002488 output_zero_point = qinfo[1]
2489 else:
2490 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002491 qinfo = kwargs["qinfo"].ints
Matthew Haddonc2025212021-10-08 21:21:05 +01002492 output_zero_point = qinfo[1][1]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002493 if op["op"] == Op.AVG_POOL2D:
Matthew Haddonc2025212021-10-08 21:21:05 +01002494 if input_dtype != DType.INT8 and output_zero_point != 0:
2495 error_result = True
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002496 elif (
2497 output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0
2498 ):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002499 error_result = True
2500
2501 info_dict = {
2502 "error_name": error_name,
2503 "error_result": error_result,
2504 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002505 "param_reqs": param_reqs,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002506 }
2507 return info_dict
2508
Matthew Haddond6ce7252021-09-29 15:35:44 +01002509 @staticmethod
2510 def evAxisSmallerZero(check=False, **kwargs):
2511 error_name = ErrorIf.AxisSmallerZero
2512 param_reqs = {"rank": None, "dtype": None, "shape": None}
2513 error_result = False
2514 error_reason = "Axis smaller than zero"
2515
2516 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002517 axis = kwargs["axis"]
Matthew Haddond6ce7252021-09-29 15:35:44 +01002518 if axis < 0:
2519 error_result = True
2520
2521 info_dict = {
2522 "error_name": error_name,
2523 "error_result": error_result,
2524 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002525 "param_reqs": param_reqs,
Matthew Haddond6ce7252021-09-29 15:35:44 +01002526 }
2527 return info_dict
2528
Matthew Haddond6ce7252021-09-29 15:35:44 +01002529 @staticmethod
2530 def evAxisLargerRank(check=False, **kwargs):
2531 error_name = ErrorIf.AxisLargerRank
2532 param_reqs = {"rank": None, "dtype": None, "shape": None}
2533 error_result = False
2534 error_reason = "Axis larger than rank"
2535
2536 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002537 axis = kwargs["axis"]
2538 shape = kwargs["input_shape"]
Matthew Haddond6ce7252021-09-29 15:35:44 +01002539 if axis > len(shape):
2540 error_result = True
2541
2542 info_dict = {
2543 "error_name": error_name,
2544 "error_result": error_result,
2545 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002546 "param_reqs": param_reqs,
Matthew Haddond6ce7252021-09-29 15:35:44 +01002547 }
2548 return info_dict
2549
Matthew Haddond6ce7252021-09-29 15:35:44 +01002550 @staticmethod
2551 def evShapeOfAxisNotOne(check=False, **kwargs):
2552 error_name = ErrorIf.ShapeOfAxisNotOne
2553 param_reqs = {"rank": None, "dtype": None, "shape": None}
2554 error_result = False
2555 error_reason = "shape[axis] is not equal to 1"
2556
2557 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002558 axis = kwargs["axis"]
2559 shape = kwargs["output_shape"]
Matthew Haddond6ce7252021-09-29 15:35:44 +01002560 if (0 <= axis < len(shape)) and shape[axis] != 1:
2561 error_result = True
2562
2563 info_dict = {
2564 "error_name": error_name,
2565 "error_result": error_result,
2566 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002567 "param_reqs": param_reqs,
Matthew Haddond6ce7252021-09-29 15:35:44 +01002568 }
2569 return info_dict
2570
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002571 @staticmethod
2572 def evPadSmallerZero(check=False, **kwargs):
2573 error_name = ErrorIf.PadSmallerZero
2574 param_reqs = {"rank": None, "dtype": None, "shape": None}
2575 error_result = False
2576 error_reason = "At least one pad is smaller than zero"
2577
2578 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002579 op = kwargs["op"]
2580 pad = kwargs["pad"]
2581 if op["op"] == Op.PAD:
Matthew Haddone807aae2021-10-11 18:12:58 +01002582 for padding in pad:
2583 if min(padding) < 0:
2584 error_result = True
2585 else:
2586 if min(pad) < 0:
2587 error_result = True
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002588
2589 info_dict = {
2590 "error_name": error_name,
2591 "error_result": error_result,
2592 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002593 "param_reqs": param_reqs,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002594 }
2595 return info_dict
2596
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002597 @staticmethod
2598 def evPadLargerEqualKernel(check=False, **kwargs):
2599 error_name = ErrorIf.PadLargerEqualKernel
2600 param_reqs = {"rank": None, "dtype": None, "shape": None}
2601 error_result = False
        error_reason = "At least one pad is larger than or equal to the kernel dimension"
2603
2604 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002605 pad = kwargs["pad"]
2606 kernel = kwargs["kernel"]
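            # pad ordering is [top, bottom, left, right] and kernel is [y, x],
            # matching the layout used in evPoolingOutputShapeMismatch below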
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002607 if min(pad) > 0 and min(kernel) > 1:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002608 if (
2609 pad[0] >= kernel[0]
2610 or pad[1] >= kernel[0]
2611 or pad[2] >= kernel[1]
2612 or pad[3] >= kernel[1]
2613 ):
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002614 error_result = True
2615
2616 info_dict = {
2617 "error_name": error_name,
2618 "error_result": error_result,
2619 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002620 "param_reqs": param_reqs,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002621 }
2622 return info_dict
2623
2624 @staticmethod
2625 def evPoolingOutputShapeMismatch(check=False, **kwargs):
2626 error_name = ErrorIf.PoolingOutputShapeMismatch
2627 param_reqs = {"rank": None, "dtype": None, "shape": None}
2628 error_result = False
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002629 error_reason = (
2630 "Mismatch between output shape provided and expected output shape"
2631 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002632
2633 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002634 pad = kwargs["pad"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002635 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
2636
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002637 kernel = kwargs["kernel"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002638 kernel_y, kernel_x = kernel[0], kernel[1]
2639
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002640 input_shape = kwargs["input_shape"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002641 IH, IW = input_shape[1], input_shape[2]
2642
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002643 output_shape = kwargs["output_shape"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002644 OH, OW = output_shape[1], output_shape[2]
2645
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002646 stride = kwargs["stride"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002647 stride_y, stride_x = stride[0], stride[1]
2648
2649 # calculate correct height, width dimensions
2650 if stride_x != 0 and stride_y != 0:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002651 y_correct = (
2652 IH + pad_top + pad_bottom + stride_y - kernel_y
2653 ) // stride_y
2654 x_correct = (
2655 IW + pad_left + pad_right + stride_x - kernel_x
2656 ) // stride_x
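                # e.g. (hypothetical values) IH=16, pad_top=pad_bottom=1,
                # stride_y=2, kernel_y=3 gives
                # y_correct = (16 + 1 + 1 + 2 - 3) // 2 = 8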
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002657
2658 # ensure parameters are valid
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002659 params_valid = (
2660 min(kernel) >= 1
2661 and min(stride) >= 1
2662 and min(pad) >= 0
2663 and not (
2664 pad[0] >= kernel[0]
2665 or pad[1] >= kernel[0]
2666 or pad[2] >= kernel[1]
2667 or pad[3] >= kernel[1]
2668 )
2669 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002670
2671 if params_valid and (OH != y_correct or OW != x_correct):
2672 error_result = True
2673
2674 info_dict = {
2675 "error_name": error_name,
2676 "error_result": error_result,
2677 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002678 "param_reqs": param_reqs,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002679 }
2680 return info_dict
2681
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002682 @staticmethod
2683 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
2684 error_name = ErrorIf.ArgmaxOutputShapeMismatch
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002685 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002686 error_result = False
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002687 error_reason = (
2688 "Mismatch between output shape provided and expected output shape"
2689 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002690
2691 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002692 output_shape = kwargs["output_shape"]
2693 input_shape = kwargs["input_shape"]
2694 axis = kwargs["axis"]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002695
2696 dimension_match = True
2697 axis_shift = 0
2698
2699 # Check that rank is correct before trying to check dimensions
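            # e.g. (hypothetical values) input_shape=[2, 3, 4] with axis=1 should
            # give output_shape=[2, 4]; dimensions after the axis shift left by one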
2700 if (len(input_shape) - 1) == len(output_shape):
2701 for i in range(len(input_shape)):
2702 if i == axis:
2703 axis_shift = 1
2704 continue
2705 if input_shape[i] != output_shape[i - axis_shift]:
2706 dimension_match = False
2707
2708 if not dimension_match:
2709 error_result = True
2710
2711 info_dict = {
2712 "error_name": error_name,
2713 "error_result": error_result,
2714 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002715 "param_reqs": param_reqs,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002716 }
2717 return info_dict
2718
2719 @staticmethod
2720 def evArgmaxOutputRankMismatch(check=False, **kwargs):
2721 error_name = ErrorIf.ArgmaxOutputRankMismatch
2722 param_reqs = {"rank": None, "dtype": None, "shape": None}
2723 error_result = False
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002724 error_reason = (
2725 "Mismatch between output shape provided and expected output shape"
2726 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002727
2728 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002729 output_shape = kwargs["output_shape"]
2730 input_shape = kwargs["input_shape"]
2731 axis = kwargs["axis"]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002732 valid_params = axis >= 0 and axis < len(input_shape)
2733
2734 if valid_params and (len(input_shape) - 1) != len(output_shape):
2735 error_result = True
2736
2737 info_dict = {
2738 "error_name": error_name,
2739 "error_result": error_result,
2740 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002741 "param_reqs": param_reqs,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002742 }
2743 return info_dict
2744
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002745 @staticmethod
2746 def evKernelSmallerOne(check=False, **kwargs):
2747 error_name = ErrorIf.KernelSmallerOne
2748 param_reqs = {"rank": None, "dtype": None, "shape": None}
2749 error_result = False
        error_reason = "At least one kernel dimension is smaller than one"
2751
2752 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002753 kernel = kwargs["kernel"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002754 if min(kernel) < 1:
2755 error_result = True
2756
2757 info_dict = {
2758 "error_name": error_name,
2759 "error_result": error_result,
2760 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002761 "param_reqs": param_reqs,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002762 }
2763 return info_dict
2764
2765 @staticmethod
2766 def evStrideSmallerOne(check=False, **kwargs):
2767 error_name = ErrorIf.StrideSmallerOne
2768 param_reqs = {"rank": None, "dtype": None, "shape": None}
2769 error_result = False
        error_reason = "At least one stride dimension is smaller than one"
2771
2772 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002773 stride = kwargs["stride"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002774 if min(stride) < 1:
2775 error_result = True
2776
2777 info_dict = {
2778 "error_name": error_name,
2779 "error_result": error_result,
2780 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002781 "param_reqs": param_reqs,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002782 }
2783 return info_dict
2784
Matthew Haddonc2025212021-10-08 21:21:05 +01002785 @staticmethod
Les Bell0e027d42021-11-09 14:42:14 +00002786 def evDilationSmallerOne(check=False, **kwargs):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002787 error_result = check and min(kwargs["dilation"]) < 1
Les Bell0e027d42021-11-09 14:42:14 +00002788 return {
2789 "error_name": ErrorIf.DilationSmallerOne,
2790 "error_reason": "At least one dilation is smaller than one",
2791 "param_reqs": {"rank": None, "dtype": None, "shape": None},
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002792 "error_result": error_result,
Les Bell0e027d42021-11-09 14:42:14 +00002793 }
2794
2795 @staticmethod
Matthew Haddonc2025212021-10-08 21:21:05 +01002796 def evScaleTrue(check=False, **kwargs):
2797 error_name = ErrorIf.ScaleTrue
2798 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
2799 error_result = False
2800 error_reason = "Scale set to true but input type is INT48"
2801
2802 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002803 input_dtype = kwargs["input_dtype"]
2804 scale32 = kwargs["scale32"]
Matthew Haddonc2025212021-10-08 21:21:05 +01002805 if scale32 and input_dtype == DType.INT48:
2806 error_result = True
2807
2808 info_dict = {
2809 "error_name": error_name,
2810 "error_result": error_result,
2811 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002812 "param_reqs": param_reqs,
Matthew Haddonc2025212021-10-08 21:21:05 +01002813 }
2814 return info_dict
2815
2816 @staticmethod
2817 def evScaleNotTrue(check=False, **kwargs):
2818 error_name = ErrorIf.ScaleNotTrue
2819 param_reqs = {"rank": None, "dtype": None, "shape": None}
2820 error_result = False
2821 error_reason = "Scale set to false but double round set to true"
2822
2823 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002824 scale32 = kwargs["scale32"]
2825 double_round = kwargs["double_round"]
Matthew Haddonc2025212021-10-08 21:21:05 +01002826 if not scale32 and double_round:
2827 error_result = True
2828
2829 info_dict = {
2830 "error_name": error_name,
2831 "error_result": error_result,
2832 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002833 "param_reqs": param_reqs,
Matthew Haddonc2025212021-10-08 21:21:05 +01002834 }
2835 return info_dict
2836
Matthew Haddone807aae2021-10-11 18:12:58 +01002837 @staticmethod
2838 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
2839 error_name = ErrorIf.TensorSizeInputOutputMismatch
2840 param_reqs = {"rank": None, "dtype": None, "shape": None}
2841 error_result = False
2842 error_reason = "Input tensor size does not match output tensor size"
2843
2844 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002845 input_shape = kwargs["input_shape"]
2846 output_shape = kwargs["output_shape"]
Matthew Haddone807aae2021-10-11 18:12:58 +01002847 input_size = np.prod(input_shape)
2848 output_size = np.prod(output_shape)
2849 if input_size != output_size:
2850 error_result = True
2851
2852 info_dict = {
2853 "error_name": error_name,
2854 "error_result": error_result,
2855 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002856 "param_reqs": param_reqs,
Matthew Haddone807aae2021-10-11 18:12:58 +01002857 }
2858 return info_dict
2859
2860 @staticmethod
2861 def evStartSmallerZero(check=False, **kwargs):
2862 error_name = ErrorIf.StartSmallerZero
2863 param_reqs = {"rank": None, "dtype": None, "shape": None}
2864 error_result = False
2865 error_reason = "Starting point smaller than zero"
2866
2867 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002868 input_shape = kwargs["input_shape"]
2869 start = kwargs["start"]
Matthew Haddone807aae2021-10-11 18:12:58 +01002870 rank = len(input_shape)
2871 if len(start) == rank:
2872 for index in range(rank):
2873 if start[index] < 0:
2874 error_result = True
2875
2876 info_dict = {
2877 "error_name": error_name,
2878 "error_result": error_result,
2879 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002880 "param_reqs": param_reqs,
Matthew Haddone807aae2021-10-11 18:12:58 +01002881 }
2882 return info_dict
2883
Matthew Haddone807aae2021-10-11 18:12:58 +01002884 @staticmethod
2885 def evSizeSmallerEqualZero(check=False, **kwargs):
2886 error_name = ErrorIf.SizeSmallerEqualZero
2887 param_reqs = {"rank": None, "dtype": None, "shape": None}
2888 error_result = False
2889 error_reason = "Size smaller than or equal to zero"
2890
2891 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002892 input_shape = kwargs["input_shape"]
2893 size = kwargs["size"]
Matthew Haddone807aae2021-10-11 18:12:58 +01002894 rank = len(input_shape)
2895 if len(size) == rank:
2896 for index in range(rank):
2897 if size[index] <= 0:
2898 error_result = True
2899
2900 info_dict = {
2901 "error_name": error_name,
2902 "error_result": error_result,
2903 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002904 "param_reqs": param_reqs,
Matthew Haddone807aae2021-10-11 18:12:58 +01002905 }
2906 return info_dict
2907
Matthew Haddone807aae2021-10-11 18:12:58 +01002908 @staticmethod
2909 def evStartSizeOutsideBounds(check=False, **kwargs):
2910 error_name = ErrorIf.StartSizeOutsideBounds
2911 param_reqs = {"rank": None, "dtype": None, "shape": None}
2912 error_result = False
2913 error_reason = "starting point plus size larger than input dimension"
2914
2915 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002916 input_shape = kwargs["input_shape"]
2917 start = kwargs["start"]
2918 size = kwargs["size"]
Matthew Haddone807aae2021-10-11 18:12:58 +01002919 rank = len(input_shape)
2920 if len(start) == rank and len(size) == rank:
2921 for index in range(rank):
2922 if start[index] + size[index] > input_shape[index]:
2923 error_result = True
2924
2925 info_dict = {
2926 "error_name": error_name,
2927 "error_result": error_result,
2928 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002929 "param_reqs": param_reqs,
Matthew Haddone807aae2021-10-11 18:12:58 +01002930 }
2931 return info_dict
2932
Matthew Haddone807aae2021-10-11 18:12:58 +01002933 @staticmethod
2934 def evSizeOutputShapeMismatch(check=False, **kwargs):
2935 error_name = ErrorIf.SizeOutputShapeMismatch
2936 param_reqs = {"rank": None, "dtype": None, "shape": None}
2937 error_result = False
2938 error_reason = "Size does not match output dimension"
2939
2940 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002941 input_shape = kwargs["input_shape"]
2942 output_shape = kwargs["output_shape"]
2943 size = kwargs["size"]
Matthew Haddone807aae2021-10-11 18:12:58 +01002944 rank = len(input_shape)
2945 if len(size) == rank:
2946 for index in range(rank):
2947 if size[index] != output_shape[index]:
2948 error_result = True
2949
2950 info_dict = {
2951 "error_name": error_name,
2952 "error_result": error_result,
2953 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002954 "param_reqs": param_reqs,
Matthew Haddone807aae2021-10-11 18:12:58 +01002955 }
2956 return info_dict
2957
2958 @staticmethod
2959 def evInputSizeStartLengthMismatch(check=False, **kwargs):
2960 error_name = ErrorIf.InputSizeStartLengthMismatch
2961 param_reqs = {"rank": None, "dtype": None, "shape": None}
2962 error_result = False
2963 error_reason = "rank of input not equal to length of start or size"
2964
2965 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002966 input_shape = kwargs["input_shape"]
2967 start = kwargs["start"]
2968 size = kwargs["size"]
Matthew Haddone807aae2021-10-11 18:12:58 +01002969 rank = len(input_shape)
2970 if rank != len(start) or rank != len(size):
2971 error_result = True
2972
2973 info_dict = {
2974 "error_name": error_name,
2975 "error_result": error_result,
2976 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002977 "param_reqs": param_reqs,
Matthew Haddone807aae2021-10-11 18:12:58 +01002978 }
2979 return info_dict
2980
2981 @staticmethod
2982 def evIndexOutsideBounds(check=False, **kwargs):
2983 error_name = ErrorIf.IndexOutsideBounds
2984 param_reqs = {"rank": None, "dtype": None, "shape": None}
2985 error_result = False
2986 error_reason = "Index outside of allowed bounds"
2987
2988 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002989 input_shape = kwargs["input_shape"]
2990 perms = kwargs["perms"]
Matthew Haddone807aae2021-10-11 18:12:58 +01002991 rank = len(input_shape)
2992
2993 for index in perms:
                if index < 0 or index >= rank:
2995 error_result = True
2996
2997 info_dict = {
2998 "error_name": error_name,
2999 "error_result": error_result,
3000 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003001 "param_reqs": param_reqs,
Matthew Haddone807aae2021-10-11 18:12:58 +01003002 }
3003 return info_dict
3004
3005 @staticmethod
3006 def evIndexUsedTwice(check=False, **kwargs):
3007 error_name = ErrorIf.IndexUsedTwice
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003008 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
Matthew Haddone807aae2021-10-11 18:12:58 +01003009 error_result = False
3010 error_reason = "Index used multiple times"
3011
3012 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003013 perms = kwargs["perms"]
Matthew Haddone807aae2021-10-11 18:12:58 +01003014
3015 unique_indices = []
3016 for index in perms:
3017 if index in unique_indices:
3018 error_result = True
3019 else:
3020 unique_indices.append(index)
3021
3022 info_dict = {
3023 "error_name": error_name,
3024 "error_result": error_result,
3025 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003026 "param_reqs": param_reqs,
Matthew Haddone807aae2021-10-11 18:12:58 +01003027 }
3028 return info_dict
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003029
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003030 @staticmethod
3031 def evMaxSmallerMin(check=False, **kwargs):
3032 error_name = ErrorIf.MaxSmallerMin
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003033 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003034 error_result = False
3035 error_reason = "Max value smaller than min value"
3036
3037 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003038 max_val = kwargs["max_val"]
3039 min_val = kwargs["min_val"]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003040 if max_val < min_val:
3041 error_result = True
3042
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003043 info_dict = {
3044 "error_name": error_name,
3045 "error_result": error_result,
3046 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003047 "param_reqs": param_reqs,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003048 }
3049 return info_dict
3050
3051 @staticmethod
3052 def evConcatInputRankMismatch(check=False, **kwargs):
3053 error_name = ErrorIf.ConcatInputRankMismatch
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003054 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003055 error_result = False
3056 error_reason = "Input ranks are not identical"
3057
3058 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003059 inputs = kwargs["inputs"]
3060 input_shape = kwargs["input_shape"]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003061 for input in inputs:
3062 if len(input.shape) != len(input_shape):
3063 error_result = True
3064
3065 info_dict = {
3066 "error_name": error_name,
3067 "error_result": error_result,
3068 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003069 "param_reqs": param_reqs,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003070 }
3071 return info_dict
3072
3073 @staticmethod
3074 def evConcatInputDimMismatch(check=False, **kwargs):
3075 error_name = ErrorIf.ConcatInputDimMismatch
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003076 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003077 error_result = False
        error_reason = "Input dimensions differ on an axis other than the concatenation axis"
3079
3080 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003081 inputs = kwargs["inputs"]
3082 input_shape = kwargs["input_shape"]
3083 axis = kwargs["axis"]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003084
3085 # Ensure rank is valid before checking dims.
3086 valid_rank = True
3087 for input in inputs:
3088 if len(input.shape) != len(input_shape):
3089 valid_rank = False
3090
3091 if valid_rank:
3092 for input in inputs:
3093 for i, dim in enumerate(input.shape):
3094 if dim != input_shape[i] and axis != i:
3095 error_result = True
3096
3097 info_dict = {
3098 "error_name": error_name,
3099 "error_result": error_result,
3100 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003101 "param_reqs": param_reqs,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003102 }
3103 return info_dict
3104
Matthew Haddon630c17c2021-10-14 15:05:41 +01003105 @staticmethod
Matthew Haddon01c359d2021-10-15 16:30:48 +01003106 def evConcatShapeSumMismatch(check=False, **kwargs):
3107 error_name = ErrorIf.ConcatShapeSumMismatch
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003108 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
Matthew Haddon01c359d2021-10-15 16:30:48 +01003109 error_result = False
3110 error_reason = "Sum of dimensions on axis not equal to output dimension"
3111
3112 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003113 inputs = kwargs["inputs"]
3114 input_shape = kwargs["input_shape"]
3115 output_shape = kwargs["output_shape"]
3116 axis = kwargs["axis"]
Matthew Haddon01c359d2021-10-15 16:30:48 +01003117
3118 # Ensure rank is valid before checking dims.
3119 valid_params = True
3120 for input in inputs:
3121 if len(input.shape) != len(input_shape):
3122 valid_params = False
3123 if axis < 0 or axis > len(input_shape):
3124 valid_params = False
3125
3126 if valid_params:
3127 axis_dim_sum = 0
3128 for input in inputs:
3129 axis_dim_sum += input.shape[axis]
3130
3131 if axis_dim_sum != output_shape[axis]:
3132 error_result = True
3133
Matthew Haddon01c359d2021-10-15 16:30:48 +01003134 info_dict = {
3135 "error_name": error_name,
3136 "error_result": error_result,
3137 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003138 "param_reqs": param_reqs,
Matthew Haddon01c359d2021-10-15 16:30:48 +01003139 }
3140 return info_dict
3141
3142 @staticmethod
Matthew Haddon630c17c2021-10-14 15:05:41 +01003143 def evInputListThenGraphMismatch(check=False, **kwargs):
3144 error_name = ErrorIf.CondIfInputListThenGraphMismatch
3145 param_reqs = {"rank": None, "dtype": None, "shape": None}
3146 error_result = False
3147 error_reason = "Input list shape does not match then-graph shape"
3148
3149 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003150 a = kwargs["a"]
3151 b = kwargs["b"]
3152 basicBlocks = kwargs["basicBlocks"]
Matthew Haddon630c17c2021-10-14 15:05:41 +01003153 then_block = basicBlocks[1]
3154 then_inputs = then_block.inputs
3155 then_tens = then_block.tensors
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003156 if (a.shape != then_tens[then_inputs[0]].shape) or (
3157 b.shape != then_tens[then_inputs[1]].shape
3158 ):
Matthew Haddon630c17c2021-10-14 15:05:41 +01003159 error_result = True
3160
3161 info_dict = {
3162 "error_name": error_name,
3163 "error_result": error_result,
3164 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003165 "param_reqs": param_reqs,
Matthew Haddon630c17c2021-10-14 15:05:41 +01003166 }
3167 return info_dict
3168
Matthew Haddon630c17c2021-10-14 15:05:41 +01003169 @staticmethod
3170 def evInputListElseGraphMismatch(check=False, **kwargs):
3171 error_name = ErrorIf.CondIfInputListElseGraphMismatch
3172 param_reqs = {"rank": None, "dtype": None, "shape": None}
3173 error_result = False
3174 error_reason = "Input list shape does not match else-graph shape"
3175
3176 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003177 a = kwargs["a"]
3178 b = kwargs["b"]
3179 basicBlocks = kwargs["basicBlocks"]
Matthew Haddon630c17c2021-10-14 15:05:41 +01003180 else_block = basicBlocks[2]
3181 else_inputs = else_block.inputs
3182 else_tens = else_block.tensors
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003183 if (a.shape != else_tens[else_inputs[0]].shape) or (
3184 b.shape != else_tens[else_inputs[1]].shape
3185 ):
Matthew Haddon630c17c2021-10-14 15:05:41 +01003186 error_result = True
3187
3188 info_dict = {
3189 "error_name": error_name,
3190 "error_result": error_result,
3191 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003192 "param_reqs": param_reqs,
Matthew Haddon630c17c2021-10-14 15:05:41 +01003193 }
3194 return info_dict
3195
Matthew Haddon630c17c2021-10-14 15:05:41 +01003196 @staticmethod
3197 def evOutputListThenGraphMismatch(check=False, **kwargs):
3198 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
3199 param_reqs = {"rank": None, "dtype": None, "shape": None}
3200 error_result = False
3201 error_reason = "Output list shape does not match then-graph shape"
3202
3203 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003204 basicBlocks = kwargs["basicBlocks"]
Matthew Haddon630c17c2021-10-14 15:05:41 +01003205 cond_block = basicBlocks[0]
3206 cond_outputs = cond_block.outputs
3207 cond_tens = cond_block.tensors
3208 then_block = basicBlocks[1]
3209 then_outputs = then_block.outputs
3210 then_tens = then_block.tensors
3211 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
3212 error_result = True
3213
3214 info_dict = {
3215 "error_name": error_name,
3216 "error_result": error_result,
3217 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003218 "param_reqs": param_reqs,
Matthew Haddon630c17c2021-10-14 15:05:41 +01003219 }
3220 return info_dict
3221
Matthew Haddon630c17c2021-10-14 15:05:41 +01003222 @staticmethod
3223 def evOutputListElseGraphMismatch(check=False, **kwargs):
3224 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
3225 param_reqs = {"rank": None, "dtype": None, "shape": None}
3226 error_result = False
3227 error_reason = "Output list shape does not match else-graph shape"
3228
3229 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003230 basicBlocks = kwargs["basicBlocks"]
Matthew Haddon630c17c2021-10-14 15:05:41 +01003231 cond_block = basicBlocks[0]
3232 cond_outputs = cond_block.outputs
3233 cond_tens = cond_block.tensors
3234 else_block = basicBlocks[2]
3235 else_outputs = else_block.outputs
3236 else_tens = else_block.tensors
3237 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
3238 error_result = True
3239
3240 info_dict = {
3241 "error_name": error_name,
3242 "error_result": error_result,
3243 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003244 "param_reqs": param_reqs,
Matthew Haddon630c17c2021-10-14 15:05:41 +01003245 }
3246 return info_dict
3247
Matthew Haddon630c17c2021-10-14 15:05:41 +01003248 @staticmethod
3249 def evInputListOutputListMismatch(check=False, **kwargs):
3250 error_name = ErrorIf.InputListOutputListMismatch
3251 param_reqs = {"rank": None, "dtype": None, "shape": None}
3252 error_result = False
3253 error_reason = "Input list does not match output list"
3254
3255 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003256 basicBlocks = kwargs["basicBlocks"]
Matthew Haddon630c17c2021-10-14 15:05:41 +01003257 while_block = basicBlocks[0]
3258 while_inputs = while_block.inputs
3259 while_outputs = while_block.outputs
3260 while_tens = while_block.tensors
3261 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
3262 error_result = True
3263
3264 info_dict = {
3265 "error_name": error_name,
3266 "error_result": error_result,
3267 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003268 "param_reqs": param_reqs,
Matthew Haddon630c17c2021-10-14 15:05:41 +01003269 }
3270 return info_dict
3271
Matthew Haddon630c17c2021-10-14 15:05:41 +01003272 @staticmethod
3273 def evInputListCondGraphMismatch(check=False, **kwargs):
3274 error_name = ErrorIf.InputListCondGraphMismatch
3275 param_reqs = {"rank": None, "dtype": None, "shape": None}
3276 error_result = False
3277 error_reason = "Input list does not match cond graph"
3278
3279 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003280 basicBlocks = kwargs["basicBlocks"]
Matthew Haddon630c17c2021-10-14 15:05:41 +01003281 while_block = basicBlocks[0]
3282 while_inputs = while_block.inputs
3283 while_tens = while_block.tensors
3284 cond_block = basicBlocks[1]
3285 cond_inputs = cond_block.inputs
3286 cond_tens = cond_block.tensors
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003287 if (
3288 while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
3289 ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
Matthew Haddon630c17c2021-10-14 15:05:41 +01003290 error_result = True
3291
3292 info_dict = {
3293 "error_name": error_name,
3294 "error_result": error_result,
3295 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003296 "param_reqs": param_reqs,
Matthew Haddon630c17c2021-10-14 15:05:41 +01003297 }
3298 return info_dict
3299
Matthew Haddon630c17c2021-10-14 15:05:41 +01003300 @staticmethod
3301 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
3302 error_name = ErrorIf.InputListBodyGraphInputMismatch
3303 param_reqs = {"rank": None, "dtype": None, "shape": None}
3304 error_result = False
3305 error_reason = "Input list does not match body graph input"
3306
3307 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003308 basicBlocks = kwargs["basicBlocks"]
Matthew Haddon630c17c2021-10-14 15:05:41 +01003309 while_block = basicBlocks[0]
3310 while_inputs = while_block.inputs
3311 while_tens = while_block.tensors
3312 body_block = basicBlocks[2]
            body_inputs = body_block.inputs
            body_tens = body_block.tensors
            if (
                while_tens[while_inputs[0]].shape != body_tens[body_inputs[0]].shape
            ) or (
                while_tens[while_inputs[1]].shape != body_tens[body_inputs[2]].shape
3319 ):
Matthew Haddon630c17c2021-10-14 15:05:41 +01003320 error_result = True
3321
3322 info_dict = {
3323 "error_name": error_name,
3324 "error_result": error_result,
3325 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003326 "param_reqs": param_reqs,
Matthew Haddon630c17c2021-10-14 15:05:41 +01003327 }
3328 return info_dict
3329
Matthew Haddon630c17c2021-10-14 15:05:41 +01003330 @staticmethod
3331 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
3332 error_name = ErrorIf.InputListBodyGraphOutputMismatch
3333 param_reqs = {"rank": None, "dtype": None, "shape": None}
3334 error_result = False
3335 error_reason = "Input list does not match body graph output"
3336
3337 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003338 basicBlocks = kwargs["basicBlocks"]
Matthew Haddon630c17c2021-10-14 15:05:41 +01003339 while_block = basicBlocks[0]
3340 while_inputs = while_block.inputs
3341 while_tens = while_block.tensors
3342 body_block = basicBlocks[2]
3343 body_outputs = body_block.outputs
3344 body_tens = body_block.tensors
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003345 if (
3346 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
3347 ) or (
3348 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
3349 ):
Matthew Haddon630c17c2021-10-14 15:05:41 +01003350 error_result = True
3351 info_dict = {
3352 "error_name": error_name,
3353 "error_result": error_result,
3354 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003355 "param_reqs": param_reqs,
Matthew Haddon630c17c2021-10-14 15:05:41 +01003356 }
3357 return info_dict
3358
Matthew Haddon630c17c2021-10-14 15:05:41 +01003359 @staticmethod
3360 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
3361 error_name = ErrorIf.CondGraphOutputNotMatchingBool
3362 param_reqs = {"rank": None, "dtype": None, "shape": None}
3363 error_result = False
        error_reason = "Cond graph output is not a list of booleans"
3365
3366 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003367 basicBlocks = kwargs["basicBlocks"]
Matthew Haddon630c17c2021-10-14 15:05:41 +01003368 cond_block = basicBlocks[1]
3369 cond_outputs = cond_block.outputs
3370 cond_tens = cond_block.tensors
3371 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
3372 error_result = True
3373
3374 info_dict = {
3375 "error_name": error_name,
3376 "error_result": error_result,
3377 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003378 "param_reqs": param_reqs,
Matthew Haddon630c17c2021-10-14 15:05:41 +01003379 }
3380 return info_dict
3381
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003382
Matthew Haddonb724efc2021-08-25 16:40:29 +01003383class TosaInvalidValidator:
Matthew Haddonb724efc2021-08-25 16:40:29 +01003384 @staticmethod
3385 def ivWrongDataTypeOrModeResize(**kwargs):
3386 input_dtype = kwargs["input_dtype"]
3387 args = kwargs["args"]
3388 mode = args[0]
Matthew Haddonb724efc2021-08-25 16:40:29 +01003389 output_dtype = args[8]
3390
3391 if mode == ResizeMode.BILINEAR:
3392 # Invalid output data type / Invalid input datatype
            return not (
                (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                or (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                or (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
            )
3399 elif mode == ResizeMode.NEAREST:
3400 # Invalid output data type / Invalid input datatype
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003401 return (input_dtype != output_dtype) or (
                input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT]
Matthew Haddonb724efc2021-08-25 16:40:29 +01003403 )
3404 else:
3405 # Invalid resize mode
3406 return True
3407
3408 @staticmethod
3409 def ivBadStride(**kwargs):
3410 input_dtype = kwargs["input_dtype"]
3411 args = kwargs["args"]
3412 stride_x = args[1][0]
3413 stride_y = args[1][1]
3414 stride_fp_x = args[4][0]
3415 stride_fp_y = args[4][1]
3416
3417 if input_dtype == DType.FLOAT:
3418 if stride_fp_x <= 0 or stride_fp_y <= 0:
3419 # Negative or zero stride
3420 return True
3421 else:
3422 if stride_x <= 0 or stride_y <= 0:
3423 # Negative or zero stride
3424 return True
3425 return False
3426
Matthew Haddonb724efc2021-08-25 16:40:29 +01003427 @staticmethod
Les Bell0e027d42021-11-09 14:42:14 +00003428 def ivHeightWidthInvalid(**kwargs):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003429 opName = kwargs["opName"]
Matthew Haddonb724efc2021-08-25 16:40:29 +01003430
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003431 inputShapes = kwargs["shapeList"]
Les Bell0e027d42021-11-09 14:42:14 +00003432 input_shape = inputShapes[0]
Matthew Haddonb724efc2021-08-25 16:40:29 +01003433
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003434 args = kwargs["args"]
Matthew Haddonb724efc2021-08-25 16:40:29 +01003435 strides = args[0]
3436 padding = args[1]
Les Bell0e027d42021-11-09 14:42:14 +00003437
Matthew Haddonb724efc2021-08-25 16:40:29 +01003438 if opName.endswith("pool2d"):
Les Bell0e027d42021-11-09 14:42:14 +00003439 # avg_pool2d, max_pool2d
3440 kernel_shape = args[2]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003441 h = (
3442 input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
3443 ) // strides[0]
3444 w = (
3445 input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
3446 ) // strides[1]
Les Bell0e027d42021-11-09 14:42:14 +00003447 # return True if any dimension is < 1
3448 return h < 1 or w < 1
Matthew Haddonb724efc2021-08-25 16:40:29 +01003449
Les Bell0e027d42021-11-09 14:42:14 +00003450 if opName.startswith("transpose_conv2d"):
3451 # transpose_conv2d
3452 dilations = args[2]
3453 output_shape = args[3]
3454 filter_shape = inputShapes[1]
3455 kernel_shape = filter_shape[1:-1]
Matthew Haddonb724efc2021-08-25 16:40:29 +01003456
Les Bell0e027d42021-11-09 14:42:14 +00003457 def get_out_size(in_size, stride, kernel_size, dilation, out_pad, in_pad):
3458 """Calculate the transpose_conv2d output size for a dimension.
Matthew Haddonb724efc2021-08-25 16:40:29 +01003459
Les Bell0e027d42021-11-09 14:42:14 +00003460 Based on the keras function deconv_output_length, in
3461 https://github.com/keras-team/keras/blob/master/keras/utils/conv_utils.py
Matthew Haddonb724efc2021-08-25 16:40:29 +01003462
Les Bell0e027d42021-11-09 14:42:14 +00003463 Args:
3464 in_size: the input size - int
3465 stride: the stride - int
3466 kernel_size: the kernel size - int
3467 dilation: the kernel dilation - int
3468 out_pad: the output padding - int
3469 in_pad: the input padding - int
3470
3471 Returns:
3472 the output size
3473 """
3474 dilated_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003475 return (
3476 (in_size - 1) * stride + dilated_kernel_size - 2 * in_pad + out_pad
3477 )
Les Bell0e027d42021-11-09 14:42:14 +00003478
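            # e.g. (hypothetical values) in_size=8, stride=2, kernel_size=3,
            # dilation=1, out_pad=0, in_pad=1 gives a dilated kernel of 3 and an
            # output size of (8 - 1) * 2 + 3 - 2 * 1 + 0 = 15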
3479 for pad_h, pad_w in (
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003480 (kernel_shape[0] - 1, kernel_shape[1] - 1), # FULL padding
3481 (kernel_shape[0] // 2, kernel_shape[1] // 2), # SAME padding
3482 (0, 0), # VALID padding
Les Bell0e027d42021-11-09 14:42:14 +00003483 ):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003484 h = get_out_size(
3485 input_shape[1],
3486 strides[0],
3487 kernel_shape[0],
3488 dilations[0],
3489 padding[0],
3490 pad_h,
3491 )
3492 w = get_out_size(
3493 input_shape[2],
3494 strides[1],
3495 kernel_shape[1],
3496 dilations[1],
3497 padding[1],
3498 pad_w,
3499 )
Les Bell0e027d42021-11-09 14:42:14 +00003500 if output_shape[1] == h and output_shape[2] == w:
3501 return False
3502
3503 # output shape does not match the expected shape for any padding option
Matthew Haddonb724efc2021-08-25 16:40:29 +01003504 return True
Les Bell0e027d42021-11-09 14:42:14 +00003505
3506 if "conv2d" in opName or "conv3d" in opName:
3507 # conv2d, conv3d, depthwise_conv2d
3508 dilations = args[2]
3509 filter_shape = inputShapes[1]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003510 kernel_shape = (
3511 filter_shape[0:2]
3512 if opName.startswith("depthwise_conv2d")
3513 else filter_shape[1:-1]
3514 )
Les Bell0e027d42021-11-09 14:42:14 +00003515
3516 for i in range(len(kernel_shape)):
3517 dim = (
3518 input_shape[i + 1]
3519 - kernel_shape[i]
3520 - (kernel_shape[i] - 1) * (dilations[i] - 1)
3521 + padding[i * 2 + 0]
3522 + padding[i * 2 + 1]
3523 ) // strides[i] + 1
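                # e.g. (hypothetical values) input dim 10, kernel 3, dilation 2,
                # padding 1+1, stride 1 gives dim = (10 - 3 - 2 + 1 + 1) // 1 + 1 = 8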
3524 # return True if any dimension is < 1
3525 if dim < 1:
3526 return True
3527 return False
3528
3529 assert False, f"Unrecognized Op: {opName}"
Matthew Haddonb724efc2021-08-25 16:40:29 +01003530
3531 @staticmethod
3532 def ivNonPositiveOutputShape(**kwargs):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003533 args = kwargs["args"]
Matthew Haddonb724efc2021-08-25 16:40:29 +01003534 output_shape = args[3]
3535 if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Non-positive output shape
3537 return True
3538 return False
3539
3540
Eric Kunzee5e26762020-10-13 16:11:07 -07003541class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01003542 # Maximum rank of tensor supported by test generator.
3543 TOSA_TENSOR_MAX_RANK = 6
3544
Eric Kunzee5e26762020-10-13 16:11:07 -07003545 def __init__(self, args):
3546 self.args = args
3547 self.basePath = args.output_dir
3548 self.random_seed = args.random_seed
3549 self.ser = None
3550 self.rng = np.random.default_rng(self.random_seed)
3551 self.createDynamicOpLists()
3552 self.initOpListDefaults()
3553 self.quantGen = TosaQuantGen()
        # Force makeShape to use a specific starting shape
3555 self.targetted_shape = None
3556
3557 def createSerializer(self, opName, testPath):
3558 self.testPath = os.path.join(opName, testPath)
3559
3560 fullPath = os.path.join(self.basePath, self.testPath)
3561 os.makedirs(fullPath, exist_ok=True)
3562 self.ser = ts.TosaSerializer(fullPath)
3563
3564 def getSerializer(self):
3565 return self.ser
3566
3567 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003568 with open(
3569 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
3570 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -07003571 fd.write(self.ser.serialize())
3572
Kevin Cheng550ccc52021-03-03 11:21:43 -08003573 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
3574 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -07003575
Matthew Haddon74567092021-07-16 15:38:20 +01003576 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003577 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +01003578 seed = self.random_seed + 1
3579 self.rng = np.random.default_rng(seed)
3580
Eric Kunzee5e26762020-10-13 16:11:07 -07003581 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -07003582 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -07003583 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -07003584 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07003585 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07003586 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07003587 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01003588 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
3589 elif dtype == DType.UINT8:
3590 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07003591 elif dtype == DType.INT16:
3592 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
3593 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003594 return np.int32(
3595 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
3596 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003597 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003598 return np.int64(
3599 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
3600 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003601 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +01003602 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07003603 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003604 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003605
Kevin Cheng989cb052021-04-28 16:29:44 -07003606 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07003607 placeholders = []
3608
Kevin Cheng989cb052021-04-28 16:29:44 -07003609 assert len(shape_list) == len(dtype_list)
3610
3611 for idx, shape in enumerate(shape_list):
3612 arr = self.getRandTensor(shape, dtype_list[idx])
3613 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07003614
3615 return placeholders
3616
Kevin Cheng989cb052021-04-28 16:29:44 -07003617 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07003618 consts = []
3619
Kevin Cheng989cb052021-04-28 16:29:44 -07003620 assert len(shape_list) == len(dtype_list)
3621
3622 for idx, shape in enumerate(shape_list):
3623 arr = self.getRandTensor(shape, dtype_list[idx])
3624 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07003625
3626 return consts
3627
3628 def makeShape(self, rank):
3629 if self.targetted_shape:
3630 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003631 return np.int32(
3632 self.rng.integers(
3633 low=self.args.tensor_shape_range[0],
3634 high=self.args.tensor_shape_range[1],
3635 size=rank,
3636 )
3637 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003638
3639 def setTargetShape(self, shape):
3640 self.targetted_shape = shape
3641
3642 def randInt(self, low=0, high=256):
3643 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
3644
3645 def getRandNumberDType(self, dtype):
3646 if dtype == DType.FLOAT:
3647 return self.rng.random()
3648 elif dtype == DType.BOOL:
3649 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -07003650 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07003651 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07003652 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -07003653 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01003654 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -07003655 elif dtype == DType.INT16:
3656 low, high = (-32768, 32768)
3657 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003658 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -07003659 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003660 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -07003661 # Special size
3662 return np.int64(self.rng.integers(low, high, size=1))[0]
3663 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003664 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003665
3666 return np.int32(self.rng.integers(low, high, size=1))[0]
3667
3668 def shapeStr(self, shape):
3669
3670 sStr = []
3671 # Convert to strings
3672 for i in shape:
3673 sStr.append(str(i))
3674
Kevin Cheng550ccc52021-03-03 11:21:43 -08003675 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003676
3677 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -07003678 if isinstance(t, list):
3679 assert len(t) >= 2
3680 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -07003681 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07003682 if t == DType.BOOL:
3683 return "b"
3684 elif t == DType.INT4:
3685 return "i4"
3686 elif t == DType.INT8:
3687 return "i8"
3688 elif t == DType.UINT8:
3689 return "u8"
3690 elif t == DType.INT16:
3691 return "i16"
3692 elif t == DType.INT32:
3693 return "i32"
3694 elif t == DType.INT48:
3695 return "i48"
3696 elif t == DType.FLOAT:
3697 return "float"
3698 else:
3699 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07003700
3701 def typeWidth(self, t):
        """Get the datatype width in bits."""
Kevin Cheng3a478572021-01-22 17:21:02 -08003703 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07003704 return 4
3705 elif t == DType.INT8:
3706 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08003707 elif t == DType.UINT8:
3708 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07003709 elif t == DType.INT16:
3710 return 16
3711 elif t == DType.INT32:
3712 return 32
3713 elif t == DType.INT48:
3714 return 48
Matthew Haddonc2025212021-10-08 21:21:05 +01003715 elif t == DType.FLOAT:
3716 return 32
3717 elif t == DType.BOOL:
3718 return 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003719 else:
Les Bell729b0352021-11-24 10:28:21 +00003720 raise Exception(f"Unknown dtype, cannot determine width: {t}")
Eric Kunzee5e26762020-10-13 16:11:07 -07003721
3722 # Argument generators
3723 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
3724 # Where the string descriptor is used to generate the test name and
3725 # The build_fcn_arg_list is expanded and passed to the operator test
3726 # build function
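    # Example of a single returned entry (hypothetical values):
    #   ("axis1", [1])
    # where "axis1" becomes part of the test name and [1] is expanded into the
    # build function's arguments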
3727
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003728 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
3729 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
3730
        # build_placeholder returns an int, ABS/other ops do not
3732 if isinstance(op, int):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003733 self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
3734 return result_tens
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003735 elif op["op"] == Op.IDENTITY:
3736 self.ser.addOperator(op["op"], a.name, result_tens.name, None, qinfo)
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003737 return result_tens
3738
3739 # Ensure new output type has correct qinfo
3740 if error_name == ErrorIf.WrongOutputType:
3741 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
3742 qinfo = ts.TosaSerializerQuantInfo()
3743 qinfo.UnaryQuantInfo(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003744 TosaQuantGen.getQinfo(self, a.dtype),
3745 TosaQuantGen.getQinfo(self, result_tens.dtype),
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003746 )
3747
3748 # Invalidate Input/Output list for error if checks.
3749 input_list = [a.name]
3750 output_list = [result_tens.name]
3751 pCount, cCount = op["operands"]
3752 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003753 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
3754 self, error_name, input_list, output_list
3755 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003756
Les Bell729b0352021-11-24 10:28:21 +00003757 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003758 self.ser,
3759 validator_fcns,
3760 error_name,
3761 op=op,
3762 input_dtype=a.dtype,
3763 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003764 qinfo=qinfo,
3765 result_tensor=result_tens,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003766 input_list=input_list,
3767 output_list=output_list,
3768 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003769 ):
3770 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003771
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003772 self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07003773 return result_tens
3774
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003775 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003776 result_tens = OutputShaper.binaryBroadcastOp(
3777 self.ser, self.rng, a, b, error_name
3778 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003779
3780 # Invalidate Input/Output list for error if checks.
3781 input_list = [a.name, b.name]
3782 output_list = [result_tens.name]
3783 pCount, cCount = op["operands"]
3784 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003785 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
3786 self, error_name, input_list, output_list
3787 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003788
Les Bell729b0352021-11-24 10:28:21 +00003789 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003790 self.ser,
3791 validator_fcns,
3792 error_name,
3793 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003794 input1=a,
3795 input2=b,
3796 input_dtype=a.dtype,
3797 output_dtype=result_tens.dtype,
3798 result_tensor=result_tens,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003799 input_list=input_list,
3800 output_list=output_list,
3801 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003802 ):
3803 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003804
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003805 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07003806 return result_tens
3807
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003808 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07003809 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003810 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003811 return result_tens
3812
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003813 def build_arithmetic_right_shift(
3814 self, op, a, b, round, validator_fcns=None, error_name=None
3815 ):
3816 result_tens = OutputShaper.binaryBroadcastOp(
3817 self.ser, self.rng, a, b, error_name
3818 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003819
3820 # Invalidate Input/Output list for error if checks.
3821 input_list = [a.name, b.name]
3822 output_list = [result_tens.name]
3823 pCount, cCount = op["operands"]
3824 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003825 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
3826 self, error_name, input_list, output_list
3827 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003828
Les Bell729b0352021-11-24 10:28:21 +00003829 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003830 self.ser,
3831 validator_fcns,
3832 error_name,
3833 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003834 input1=a,
3835 input2=b,
3836 input_dtype=a.dtype,
3837 output_dtype=result_tens.dtype,
3838 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003839 input_list=input_list,
3840 output_list=output_list,
3841 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003842 ):
3843 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -08003844
3845 attr = ts.TosaSerializerAttribute()
3846 attr.ArithmeticRightShiftAttribute(round)
3847
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003848 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -08003849 return result_tens
3850
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003851 def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003852 result_tens = OutputShaper.binaryBroadcastOp(
3853 self.ser, self.rng, a, b, error_name
3854 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003855
3856 # Special for multiply:
3857 # Force the result to INT32 for INT types
3858 if a.dtype != DType.FLOAT:
3859 result_tens.setDtype(DType.INT32)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003860 if error_name == ErrorIf.WrongOutputType:
3861 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
3862 outputDType = self.rng.choice(all_dtypes)
3863 result_tens.setDtype(outputDType)
3864
3865 # Invalidate Input/Output list for error if checks.
3866 input_list = [a.name, b.name]
3867 output_list = [result_tens.name]
3868 pCount, cCount = op["operands"]
3869 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003870 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
3871 self, error_name, input_list, output_list
3872 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003873
Les Bell729b0352021-11-24 10:28:21 +00003874 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003875 self.ser,
3876 validator_fcns,
3877 error_name,
3878 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003879 input1=a,
3880 input2=b,
3881 input_dtype=a.dtype,
3882 output_dtype=result_tens.dtype,
3883 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003884 input_list=input_list,
3885 output_list=output_list,
3886 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003887 ):
3888 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07003889
Kevin Chengaee1fac2020-11-11 13:54:06 -08003890 attr = ts.TosaSerializerAttribute()
3891 attr.MulAttribute(shift)
3892
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003893 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003894 return result_tens
3895
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003896 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
3897 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003898
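# The lookup table itself is serialized as a TableAttribute; only the input tensor
# appears in the operand list.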
Kevin Chengfe392ce2021-10-18 21:51:55 +00003899 attr = ts.TosaSerializerAttribute()
3900 attr.TableAttribute(table)
3901
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003902 # Invalidate Input/Output list for error if checks.
3903 input_list = [a.name]
3904 output_list = [result_tens.name]
3905 pCount, cCount = op["operands"]
3906 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003907 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
3908 self, error_name, input_list, output_list
3909 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003910
Les Bell729b0352021-11-24 10:28:21 +00003911 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003912 self.ser,
3913 validator_fcns,
3914 error_name,
3915 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003916 input_shape=a.shape,
3917 input_dtype=a.dtype,
3918 output_dtype=result_tens.dtype,
3919 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003920 input_list=input_list,
3921 output_list=output_list,
3922 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003923 ):
3924 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003925
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003926 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003927
3928 return result_tens
3929
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003930 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
3931 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
3932
3933 # Invalidate Input/Output list for error if checks.
3934 input_list = [cond.name, a.name, b.name]
3935 output_list = [result_tens.name]
3936 pCount, cCount = op["operands"]
3937 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003938 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
3939 self, error_name, input_list, output_list
3940 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003941
Les Bell729b0352021-11-24 10:28:21 +00003942 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003943 self.ser,
3944 validator_fcns,
3945 error_name,
3946 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003947 input1=cond,
3948 input2=a,
3949 input3=b,
3950 input_shape=a.shape,
3951 input_dtype=a.dtype,
3952 output_dtype=result_tens.dtype,
3953 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003954 input_list=input_list,
3955 output_list=output_list,
3956 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003957 ):
3958 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003959
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003960 self.ser.addOperator(
3961 op["op"],
3962 input_list,
3963 output_list,
3964 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003965 return result_tens
3966
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003967 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003968 result_tens = OutputShaper.binaryComparisonOp(
3969 self.ser, self.rng, a, b, error_name
3970 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003971
3972 # Invalidate Input/Output list for error if checks.
3973 input_list = [a.name, b.name]
3974 output_list = [result_tens.name]
3975 pCount, cCount = op["operands"]
3976 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003977 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
3978 self, error_name, input_list, output_list
3979 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003980
Les Bell729b0352021-11-24 10:28:21 +00003981 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003982 self.ser,
3983 validator_fcns,
3984 error_name,
3985 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003986 input1=a,
3987 input2=b,
3988 input_shape=a.shape,
3989 input_dtype=a.dtype,
3990 output_shape=result_tens.shape,
3991 output_dtype=result_tens.dtype,
3992 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003993 input_list=input_list,
3994 output_list=output_list,
3995 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003996 ):
3997 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003998
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003999 self.ser.addOperator(
4000 op["op"],
4001 input_list,
4002 output_list,
4003 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004004 return result_tens
4005
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004006 def build_argmax(self, op, a, axis, validator_fcns, error_name):
4007 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
4008
4009 # Invalidate Input/Output list for error if checks.
4010 input_list = [a.name]
4011 output_list = [result_tens.name]
4012 pCount, cCount = op["operands"]
4013 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004014 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4015 self, error_name, input_list, output_list
4016 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004017
Les Bell729b0352021-11-24 10:28:21 +00004018 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004019 self.ser,
4020 validator_fcns,
4021 error_name,
4022 op=op,
4023 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004024 input_shape=a.shape,
4025 input_dtype=a.dtype,
4026 output_shape=result_tens.shape,
4027 output_dtype=result_tens.dtype,
4028 result_tensor=result_tens,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004029 input_list=input_list,
4030 output_list=output_list,
4031 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004032 ):
4033 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004034
4035 attr = ts.TosaSerializerAttribute()
4036 attr.AxisAttribute(axis)
4037
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004038 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004039 return result_tens
4040
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004041 def build_pool2d(
4042 self,
4043 op,
4044 input,
4045 stride,
4046 pad,
4047 kernel,
4048 validator_fcns=None,
4049 error_name=None,
4050 qinfo=None,
4051 ):
4052 result_tens = OutputShaper.pool2dOp(
4053 self.ser, self.rng, input, kernel, stride, pad, error_name
4054 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004055
4056 # Ensure new output type has correct qinfo
4057 if error_name == ErrorIf.WrongInputType:
4058 if input.dtype not in [DType.INT8, DType.UINT8]:
4059 qinfo = ts.TosaSerializerQuantInfo()
4060 qinfo.UnaryQuantInfo(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004061 TosaQuantGen.getQinfo(self, input.dtype),
4062 TosaQuantGen.getQinfo(self, result_tens.dtype),
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004063 )
4064
4065 # Invalidate Input/Output list for error if checks.
4066 input_list = [input.name]
4067 output_list = [result_tens.name]
4068 pCount, cCount = op["operands"]
4069 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004070 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4071 self, error_name, input_list, output_list
4072 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004073
Les Bell729b0352021-11-24 10:28:21 +00004074 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004075 self.ser,
4076 validator_fcns,
4077 error_name,
4078 op=op,
4079 input_shape=input.shape,
4080 input_dtype=input.dtype,
4081 output_shape=result_tens.shape,
4082 output_dtype=result_tens.dtype,
4083 kernel=kernel,
4084 stride=stride,
4085 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004086 qinfo=qinfo,
4087 result_tensor=result_tens,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004088 input_list=input_list,
4089 output_list=output_list,
4090 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004091 ):
4092 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004093
4094 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07004095 attr.PoolAttribute(kernel, stride, pad)
Eric Kunzee5e26762020-10-13 16:11:07 -07004096
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004097 self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07004098 return result_tens
4099
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004100 def build_conv2d(
4101 self,
4102 op,
4103 ifm,
4104 filter,
4105 bias,
4106 strides,
4107 padding,
4108 dilations,
4109 validator_fcns=None,
4110 error_name=None,
4111 qinfo=None,
4112 ):
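# Conv2D padding carries four entries (top, bottom, left, right in the TOSA convention).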
Kevin Cheng550ccc52021-03-03 11:21:43 -08004113 assert len(padding) == 4
4114 result_tens = OutputShaper.conv2dOp(
Les Bell0e027d42021-11-09 14:42:14 +00004115 self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
4116 )
4117
4118 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004119 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
4120 DType.INT8,
4121 DType.UINT8,
4122 ):
Les Bell0e027d42021-11-09 14:42:14 +00004123 qinfo = ts.TosaSerializerQuantInfo()
4124 qinfo.ConvQuantInfo(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004125 TosaQuantGen.getQinfo(self, ifm.dtype),
4126 TosaQuantGen.getQinfo(self, result_tens.dtype),
Les Bell0e027d42021-11-09 14:42:14 +00004127 )
4128
4129 # Invalidate Input/Output list for error_if checks.
4130 input_list = [ifm.name, filter.name, bias.name]
4131 output_list = [result_tens.name]
4132 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004133 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4134 self, error_name, input_list, output_list
4135 )
Les Bell0e027d42021-11-09 14:42:14 +00004136
Les Bell729b0352021-11-24 10:28:21 +00004137 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +00004138 self.ser,
4139 validator_fcns,
4140 error_name,
4141 op=op,
4142 input_dtype=ifm.dtype,
4143 weight_dtype=filter.dtype,
4144 output_dtype=result_tens.dtype,
4145 qinfo=qinfo,
4146 input_list=input_list,
4147 num_operands=num_operands,
4148 output_list=output_list,
4149 pad=padding,
4150 stride=strides,
4151 dilation=dilations,
4152 input_shape=ifm.shape,
Les Bell729b0352021-11-24 10:28:21 +00004153 ):
4154 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004155
4156 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07004157 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07004158
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004159 self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07004160 return result_tens
4161
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004162 def build_conv3d(
4163 self,
4164 op,
4165 ifm,
4166 filter,
4167 bias,
4168 strides,
4169 padding,
4170 dilations,
4171 validator_fcns=None,
4172 error_name=None,
4173 qinfo=None,
4174 ):
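# Conv3D padding carries six entries, a before/after pair for each spatial dimension.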
Kevin Cheng1533b852021-09-01 12:51:58 -07004175 assert len(padding) == 6
4176 result_tens = OutputShaper.conv3dOp(
Les Bell0e027d42021-11-09 14:42:14 +00004177 self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
4178 )
4179
4180 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004181 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
4182 DType.INT8,
4183 DType.UINT8,
4184 ):
Les Bell0e027d42021-11-09 14:42:14 +00004185 qinfo = ts.TosaSerializerQuantInfo()
4186 qinfo.ConvQuantInfo(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004187 TosaQuantGen.getQinfo(self, ifm.dtype),
4188 TosaQuantGen.getQinfo(self, result_tens.dtype),
Les Bell0e027d42021-11-09 14:42:14 +00004189 )
4190
4191 # Invalidate Input/Output list for error_if checks.
4192 input_list = [ifm.name, filter.name, bias.name]
4193 output_list = [result_tens.name]
4194 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004195 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4196 self, error_name, input_list, output_list
4197 )
Les Bell0e027d42021-11-09 14:42:14 +00004198
Les Bell729b0352021-11-24 10:28:21 +00004199 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +00004200 self.ser,
4201 validator_fcns,
4202 error_name,
4203 op=op,
4204 input_dtype=ifm.dtype,
4205 weight_dtype=filter.dtype,
4206 output_dtype=result_tens.dtype,
4207 qinfo=qinfo,
4208 input_list=input_list,
4209 num_operands=num_operands,
4210 output_list=output_list,
4211 pad=padding,
4212 stride=strides,
4213 dilation=dilations,
4214 input_shape=ifm.shape,
Les Bell729b0352021-11-24 10:28:21 +00004215 ):
4216 return None
Kevin Cheng1533b852021-09-01 12:51:58 -07004217
4218 attr = ts.TosaSerializerAttribute()
4219 attr.ConvAttribute(padding, strides, dilations)
4220
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004221 self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
Kevin Cheng1533b852021-09-01 12:51:58 -07004222 return result_tens
4223
Kevin Cheng550ccc52021-03-03 11:21:43 -08004224 def build_transpose_conv2d(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004225 self,
4226 op,
4227 ifm,
4228 filter,
4229 bias,
4230 stride,
4231 outpad,
4232 dilation,
4233 output_shape,
4234 validator_fcns=None,
4235 error_name=None,
4236 qinfo=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004237 ):
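# The generator drives transpose_conv2d with a two-element output padding, checked by the assert below.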
4238 assert len(outpad) == 2
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004239 result_tens = OutputShaper.transposeConv2DOp(
4240 self.ser, self.rng, ifm, output_shape, error_name
4241 )
Les Bell0e027d42021-11-09 14:42:14 +00004242
4243 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004244 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
4245 DType.INT8,
4246 DType.UINT8,
4247 ):
Les Bell0e027d42021-11-09 14:42:14 +00004248 qinfo = ts.TosaSerializerQuantInfo()
4249 qinfo.ConvQuantInfo(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004250 TosaQuantGen.getQinfo(self, ifm.dtype),
4251 TosaQuantGen.getQinfo(self, result_tens.dtype),
Les Bell0e027d42021-11-09 14:42:14 +00004252 )
4253
4254 # Invalidate Input/Output list for error_if checks.
4255 input_list = [ifm.name, filter.name, bias.name]
4256 output_list = [result_tens.name]
4257 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004258 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4259 self, error_name, input_list, output_list
4260 )
Les Bell0e027d42021-11-09 14:42:14 +00004261
Les Bell729b0352021-11-24 10:28:21 +00004262 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +00004263 self.ser,
4264 validator_fcns,
4265 error_name,
4266 op=op,
4267 input_dtype=ifm.dtype,
4268 weight_dtype=filter.dtype,
4269 output_dtype=result_tens.dtype,
4270 qinfo=qinfo,
4271 input_list=input_list,
4272 num_operands=num_operands,
4273 output_list=output_list,
4274 pad=outpad,
4275 stride=stride,
4276 dilation=dilation,
4277 input_shape=ifm.shape,
Les Bell729b0352021-11-24 10:28:21 +00004278 ):
4279 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004280
4281 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07004282 attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07004283
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004284 self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07004285 return result_tens
4286
Kevin Cheng550ccc52021-03-03 11:21:43 -08004287 def build_depthwise_conv2d(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004288 self,
4289 op,
4290 ifm,
4291 filter,
4292 bias,
4293 strides,
4294 padding,
4295 dilations,
4296 validator_fcns=None,
4297 error_name=None,
4298 qinfo=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004299 ):
4300 result_tens = OutputShaper.depthwiseConv2dOp(
Les Bell0e027d42021-11-09 14:42:14 +00004301 self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
4302 )
4303
4304 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004305 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
4306 DType.INT8,
4307 DType.UINT8,
4308 ):
Les Bell0e027d42021-11-09 14:42:14 +00004309 qinfo = ts.TosaSerializerQuantInfo()
4310 qinfo.ConvQuantInfo(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004311 TosaQuantGen.getQinfo(self, ifm.dtype),
4312 TosaQuantGen.getQinfo(self, result_tens.dtype),
Les Bell0e027d42021-11-09 14:42:14 +00004313 )
4314
4315 # Invalidate Input/Output list for error_if checks.
4316 input_list = [ifm.name, filter.name, bias.name]
4317 output_list = [result_tens.name]
4318 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004319 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4320 self, error_name, input_list, output_list
4321 )
Les Bell0e027d42021-11-09 14:42:14 +00004322
Les Bell729b0352021-11-24 10:28:21 +00004323 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +00004324 self.ser,
4325 validator_fcns,
4326 error_name,
4327 op=op,
4328 input_dtype=ifm.dtype,
4329 weight_dtype=filter.dtype,
4330 output_dtype=result_tens.dtype,
4331 qinfo=qinfo,
4332 input_list=input_list,
4333 num_operands=num_operands,
4334 output_list=output_list,
4335 pad=padding,
4336 stride=strides,
4337 dilation=dilations,
4338 input_shape=ifm.shape,
Les Bell729b0352021-11-24 10:28:21 +00004339 ):
4340 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004341
4342 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07004343 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07004344
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004345 self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07004346 return result_tens
4347
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004348 def build_fully_connected(
4349 self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None
4350 ):
4351 result_tens = OutputShaper.fullyConnectedOp(
4352 self.ser, self.rng, ifm, filter, error_name
4353 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004354
4355 # Invalidate Input/Output list for error if checks.
4356 input_list = [ifm.name, filter.name, bias.name]
4357 output_list = [result_tens.name]
4358 pCount, cCount = op["operands"]
4359 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004360 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4361 self, error_name, input_list, output_list
4362 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004363
Les Bell729b0352021-11-24 10:28:21 +00004364 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004365 self.ser,
4366 validator_fcns,
4367 error_name,
4368 op=op,
4369 input_shape=ifm.shape,
4370 input_dtype=ifm.dtype,
4371 weight_dtype=filter.dtype,
4372 output_shape=result_tens.shape,
4373 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004374 qinfo=qinfo,
4375 result_tensor=result_tens,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004376 input_list=input_list,
4377 output_list=output_list,
4378 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004379 ):
4380 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004381
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004382 self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07004383 return result_tens
4384
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004385 def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
4386 result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)
4387
4388 # Invalidate Input/Output list for error if checks.
4389 input_list = [a.name, b.name]
4390 output_list = [result_tens.name]
4391 pCount, cCount = op["operands"]
4392 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004393 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4394 self, error_name, input_list, output_list
4395 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004396
Les Bell729b0352021-11-24 10:28:21 +00004397 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004398 self.ser,
4399 validator_fcns,
4400 error_name,
4401 op=op,
4402 input_shape=a.shape,
4403 input_dtype=a.dtype,
4404 input2_shape=b.shape,
4405 input2_dtype=b.dtype,
4406 output_shape=result_tens.shape,
4407 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004408 qinfo=qinfo,
4409 result_tensor=result_tens,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004410 input_list=input_list,
4411 output_list=output_list,
4412 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004413 ):
4414 return None
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004415
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004416 self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07004417 return result_tens
4418
Matthew Haddond6ce7252021-09-29 15:35:44 +01004419 def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
4420 result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)
4421
4422 # Invalidate Input/Output list for error if checks.
4423 input_list = [a.name]
4424 output_list = [result_tens.name]
4425 pCount, cCount = op["operands"]
4426 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004427 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4428 self, error_name, input_list, output_list
4429 )
Matthew Haddond6ce7252021-09-29 15:35:44 +01004430
Les Bell729b0352021-11-24 10:28:21 +00004431 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddond6ce7252021-09-29 15:35:44 +01004432 self.ser,
4433 validator_fcns,
4434 error_name,
4435 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004436 axis=axis,
4437 input_shape=a.shape,
4438 output_shape=result_tens.shape,
4439 input_dtype=a.dtype,
4440 output_dtype=result_tens.dtype,
4441 result_tensor=result_tens,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004442 input_list=input_list,
4443 output_list=output_list,
4444 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004445 ):
4446 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004447
4448 attr = ts.TosaSerializerAttribute()
4449 attr.AxisAttribute(axis)
4450
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004451 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004452 return result_tens
4453
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004454 def build_clamp(self, op, a, validator_fcns=None, error_name=None):
4455 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004456
Jeremy Johnson18e26662021-07-22 16:15:29 +01004457 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07004458
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004459 if error_name == ErrorIf.MaxSmallerMin:
4460 # Make sure the numbers are different to invoke this error
4461 while v[0] == v[1]:
4462 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
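# Deliberately swap the bounds so max_val < min_val, which is the condition this error case needs.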
4463 max_val = min(v)
4464 min_val = max(v)
Eric Kunzee5e26762020-10-13 16:11:07 -07004465 else:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004466 max_val = max(v)
4467 min_val = min(v)
Eric Kunzee5e26762020-10-13 16:11:07 -07004468
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004469 # Invalidate Input/Output list for error if checks.
4470 input_list = [a.name]
4471 output_list = [result_tens.name]
4472 pCount, cCount = op["operands"]
4473 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004474 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4475 self, error_name, input_list, output_list
4476 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004477
Les Bell729b0352021-11-24 10:28:21 +00004478 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004479 self.ser,
4480 validator_fcns,
4481 error_name,
4482 op=op,
4483 max_val=max_val,
4484 min_val=min_val,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004485 input_shape=a.shape,
4486 output_shape=result_tens.shape,
4487 input_dtype=a.dtype,
4488 output_dtype=result_tens.dtype,
4489 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004490 input_list=input_list,
4491 output_list=output_list,
4492 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004493 ):
4494 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004495
4496 attr = ts.TosaSerializerAttribute()
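# As used here, ClampAttribute takes the integer bounds in the first two arguments and the
# floating-point bounds in the last two; the pair that does not apply to the dtype is zeroed.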
4497 if a.dtype == DType.FLOAT:
4498 attr.ClampAttribute(0, 0, min_val, max_val)
4499 else:
4500 attr.ClampAttribute(min_val, max_val, 0, 0)
4501
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004502 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004503 return result_tens
4504
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004505 def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
4506 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004507 attr = ts.TosaSerializerAttribute()
4508
4509 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
4510
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004511 self.ser.addOperator(op["op"], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004512 return result_tens
4513
4514 # Needs an additional type/input
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004515 def build_prelu(self, op, a, validator_fcns=None, error_name=None):
4516 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004517
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004518 self.ser.addOperator(op["op"], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07004519 return result_tens
4520
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004521 def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
4522 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
4523
4524 # Invalidate Input/Output list for error if checks.
4525 input_list = [a.name]
4526 output_list = [result_tens.name]
4527 pCount, cCount = op["operands"]
4528 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004529 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4530 self, error_name, input_list, output_list
4531 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004532
Les Bell729b0352021-11-24 10:28:21 +00004533 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004534 self.ser,
4535 validator_fcns,
4536 error_name,
4537 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004538 input_shape=a.shape,
4539 output_shape=result_tens.shape,
4540 input_dtype=a.dtype,
4541 output_dtype=result_tens.dtype,
4542 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004543 input_list=input_list,
4544 output_list=output_list,
4545 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004546 ):
4547 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004548
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004549 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07004550 return result_tens
4551
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004552 def build_tanh(self, op, a, validator_fcns=None, error_name=None):
4553 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
4554
4555 # Invalidate Input/Output list for error if checks.
4556 input_list = [a.name]
4557 output_list = [result_tens.name]
4558 pCount, cCount = op["operands"]
4559 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004560 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4561 self, error_name, input_list, output_list
4562 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004563
Les Bell729b0352021-11-24 10:28:21 +00004564 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004565 self.ser,
4566 validator_fcns,
4567 error_name,
4568 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004569 input_shape=a.shape,
4570 output_shape=result_tens.shape,
4571 input_dtype=a.dtype,
4572 output_dtype=result_tens.dtype,
4573 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004574 input_list=input_list,
4575 output_list=output_list,
4576 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004577 ):
4578 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004579
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004580 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07004581 return result_tens
4582
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004583 def build_concat(self, op, *a, validator_fcns=None, error_name=None):
4584 if error_name != ErrorIf.WrongInputType:
4585 assert type(a[-1]) == int
Matthew Haddon818ab902021-07-27 09:12:49 +01004586
4587 # To store a variable-length list of input tensors, we need to store the axis along with it
4588 axis = a[-1]
4589 a = a[:-1]
4590
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004591 result_tens = OutputShaper.concatOp(
4592 self.ser, self.rng, axis, *a, error_name=error_name
4593 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004594
Matthew Haddon818ab902021-07-27 09:12:49 +01004595 input_tensor_names = []
4596 for tensor in a:
4597 input_tensor_names.append(tensor.name)
4598
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004599 # Invalidate Input/Output list for error if checks.
4600 input_list = input_tensor_names
4601 output_list = [result_tens.name]
4602 pCount, cCount = op["operands"]
4603 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004604 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4605 self, error_name, input_list, output_list
4606 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004607
Les Bell729b0352021-11-24 10:28:21 +00004608 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004609 self.ser,
4610 validator_fcns,
4611 error_name,
4612 op=op,
4613 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004614 input_shape=a[0].shape,
4615 output_shape=result_tens.shape,
4616 input_dtype=a[0].dtype,
4617 output_dtype=result_tens.dtype,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004618 inputs=a,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004619 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004620 input_list=input_list,
4621 output_list=output_list,
4622 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004623 ):
4624 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004625
4626 attr = ts.TosaSerializerAttribute()
4627 attr.AxisAttribute(axis)
4628
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004629 self.ser.addOperator(op["op"], input_list, output_list, attr)
Matthew Haddon848efb42021-09-09 12:30:53 +01004630 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07004631
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004632 def build_pad(
4633 self,
4634 op,
4635 a,
4636 padding,
4637 pad_const_int,
4638 pad_const_float,
4639 validator_fcns=None,
4640 error_name=None,
4641 qinfo=None,
4642 ):
Matthew Haddone807aae2021-10-11 18:12:58 +01004643 result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004644
Kevin Chengfe392ce2021-10-18 21:51:55 +00004645 attr = ts.TosaSerializerAttribute()
4646 attr.PadAttribute(padding.flatten(), pad_const_int, pad_const_float)
Eric Kunzee5e26762020-10-13 16:11:07 -07004647
Matthew Haddone807aae2021-10-11 18:12:58 +01004648 # Invalidate Input/Output list for error if checks.
Kevin Chengfe392ce2021-10-18 21:51:55 +00004649 input_list = [a.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01004650 output_list = [result_tens.name]
4651 pCount, cCount = op["operands"]
4652 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004653 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4654 self, error_name, input_list, output_list
4655 )
Matthew Haddone807aae2021-10-11 18:12:58 +01004656
Les Bell729b0352021-11-24 10:28:21 +00004657 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01004658 self.ser,
4659 validator_fcns,
4660 error_name,
4661 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004662 input_shape=a.shape,
4663 output_shape=result_tens.shape,
4664 input_dtype=a.dtype,
4665 output_dtype=result_tens.dtype,
Matthew Haddone807aae2021-10-11 18:12:58 +01004666 pad=padding,
4667 qinfo=qinfo,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004668 result_tensor=result_tens,
Matthew Haddone807aae2021-10-11 18:12:58 +01004669 input_list=input_list,
4670 output_list=output_list,
4671 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004672 ):
4673 return None
Matthew Haddone807aae2021-10-11 18:12:58 +01004674
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004675 self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
Matthew Haddone86fd342021-09-07 16:12:21 +01004676 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07004677
Matthew Haddone807aae2021-10-11 18:12:58 +01004678 def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004679 result_tens = OutputShaper.reshapeOp(
4680 self.ser, self.rng, a, newShape, error_name
4681 )
Matthew Haddone807aae2021-10-11 18:12:58 +01004682
4683 # Invalidate Input/Output list for error if checks.
4684 input_list = [a.name]
4685 output_list = [result_tens.name]
4686 pCount, cCount = op["operands"]
4687 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004688 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4689 self, error_name, input_list, output_list
4690 )
Matthew Haddone807aae2021-10-11 18:12:58 +01004691
Les Bell729b0352021-11-24 10:28:21 +00004692 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01004693 self.ser,
4694 validator_fcns,
4695 error_name,
4696 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004697 input_shape=a.shape,
4698 output_shape=result_tens.shape,
4699 input_dtype=a.dtype,
4700 output_dtype=result_tens.dtype,
4701 result_tensor=result_tens,
Matthew Haddone807aae2021-10-11 18:12:58 +01004702 input_list=input_list,
4703 output_list=output_list,
4704 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004705 ):
4706 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004707
4708 attr = ts.TosaSerializerAttribute()
4709 attr.ReshapeAttribute(newShape)
4710
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004711 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004712 return result_tens
4713
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004714 def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
4715 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
4716
4717 # Invalidate Input/Output list for error if checks.
4718 input_list = [a.name]
4719 output_list = [result_tens.name]
4720 pCount, cCount = op["operands"]
4721 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004722 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4723 self, error_name, input_list, output_list
4724 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004725
Les Bell729b0352021-11-24 10:28:21 +00004726 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004727 self.ser,
4728 validator_fcns,
4729 error_name,
4730 op=op,
4731 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004732 input_shape=a.shape,
4733 output_shape=result_tens.shape,
4734 input_dtype=a.dtype,
4735 output_dtype=result_tens.dtype,
4736 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004737 input_list=input_list,
4738 output_list=output_list,
4739 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004740 ):
4741 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004742
4743 attr = ts.TosaSerializerAttribute()
4744 attr.AxisAttribute(axis)
4745
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004746 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004747 return result_tens
4748
Matthew Haddone807aae2021-10-11 18:12:58 +01004749 def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
4750 result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004751
Kevin Chengfe392ce2021-10-18 21:51:55 +00004752 attr = ts.TosaSerializerAttribute()
4753 attr.TransposeAttribute(perms)
Eric Kunzee5e26762020-10-13 16:11:07 -07004754
Matthew Haddone807aae2021-10-11 18:12:58 +01004755 # Invalidate Input/Output list for error if checks.
Kevin Chengfe392ce2021-10-18 21:51:55 +00004756 input_list = [a.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01004757 output_list = [result_tens.name]
4758 pCount, cCount = op["operands"]
4759 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004760 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4761 self, error_name, input_list, output_list
4762 )
Matthew Haddone807aae2021-10-11 18:12:58 +01004763
Les Bell729b0352021-11-24 10:28:21 +00004764 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01004765 self.ser,
4766 validator_fcns,
4767 error_name,
4768 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004769 input_shape=a.shape,
4770 output_shape=result_tens.shape,
Matthew Haddone807aae2021-10-11 18:12:58 +01004771 perms=perms,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004772 input_dtype=a.dtype,
4773 output_dtype=result_tens.dtype,
4774 result_tensor=result_tens,
Matthew Haddone807aae2021-10-11 18:12:58 +01004775 input_list=input_list,
4776 output_list=output_list,
4777 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004778 ):
4779 return None
Matthew Haddone807aae2021-10-11 18:12:58 +01004780
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004781 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004782 return result_tens
4783
Matthew Haddone807aae2021-10-11 18:12:58 +01004784 def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004785 result_tens = OutputShaper.sliceOp(
4786 self.ser, self.rng, a, start, size, error_name
4787 )
Matthew Haddone807aae2021-10-11 18:12:58 +01004788
4789 # Invalidate Input/Output list for error if checks.
4790 input_list = [a.name]
4791 output_list = [result_tens.name]
4792 pCount, cCount = op["operands"]
4793 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004794 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4795 self, error_name, input_list, output_list
4796 )
Matthew Haddone807aae2021-10-11 18:12:58 +01004797
Les Bell729b0352021-11-24 10:28:21 +00004798 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01004799 self.ser,
4800 validator_fcns,
4801 error_name,
4802 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004803 input_shape=a.shape,
4804 output_shape=result_tens.shape,
4805 input_dtype=a.dtype,
4806 output_dtype=result_tens.dtype,
Matthew Haddone807aae2021-10-11 18:12:58 +01004807 start=start,
4808 size=size,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004809 result_tensor=result_tens,
Matthew Haddone807aae2021-10-11 18:12:58 +01004810 input_list=input_list,
4811 output_list=output_list,
4812 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004813 ):
4814 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004815
4816 attr = ts.TosaSerializerAttribute()
Matthew Haddone807aae2021-10-11 18:12:58 +01004817 attr.SliceAttribute(start, size)
Eric Kunzee5e26762020-10-13 16:11:07 -07004818
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004819 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004820 return result_tens
4821
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004822 def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
4823 result_tens = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)
4824
4825 # Invalidate Input/Output list for error if checks.
4826 input_list = [a.name]
4827 output_list = [result_tens.name]
4828 pCount, cCount = op["operands"]
4829 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004830 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4831 self, error_name, input_list, output_list
4832 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004833
Les Bell729b0352021-11-24 10:28:21 +00004834 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004835 self.ser,
4836 validator_fcns,
4837 error_name,
4838 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004839 input_shape=a.shape,
4840 output_shape=result_tens.shape,
4841 input_dtype=a.dtype,
4842 output_dtype=result_tens.dtype,
4843 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004844 input_list=input_list,
4845 output_list=output_list,
4846 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004847 ):
4848 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004849
4850 attr = ts.TosaSerializerAttribute()
4851 attr.TileAttribute(multiples)
4852
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004853 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004854 return result_tens
4855
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004856 def build_gather(self, op, values, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004857
4858 # Create a new indices tensor
4859 # here with data that doesn't exceed the dimensions of the values tensor
4860
Kevin Cheng550ccc52021-03-03 11:21:43 -08004861 K = values.shape[1] # K
4862 W = self.randInt(
4863 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
4864 ) # W
4865 indicies_arr = np.int32(
4866 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
4867 ) # (N, W)
4868 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004869
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004870 result_tens = OutputShaper.gatherOp(
4871 self.ser, self.rng, values, indicies, error_name
4872 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004873
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004874 # Invalidate Input/Output list for error if checks.
4875 input_list = [values.name, indicies.name]
4876 output_list = [result_tens.name]
4877 pCount, cCount = op["operands"]
4878 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004879 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4880 self, error_name, input_list, output_list
4881 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004882
Les Bell729b0352021-11-24 10:28:21 +00004883 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004884 self.ser,
4885 validator_fcns,
4886 error_name,
4887 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004888 input_shape=values.shape,
4889 output_shape=result_tens.shape,
4890 input_dtype=values.dtype,
4891 output_dtype=result_tens.dtype,
4892 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004893 input_list=input_list,
4894 output_list=output_list,
4895 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004896 ):
4897 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004898
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004899 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07004900
4901 return result_tens
4902
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004903 def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
Kevin Cheng77d0f762020-11-24 10:26:32 -08004904
4905 # Create a new indices tensor
4906 # here with data that doesn't exceed the dimensions of the values_in tensor
4907
Kevin Cheng550ccc52021-03-03 11:21:43 -08004908 K = values_in.shape[1] # K
4909 W = input.shape[1] # W
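# Unlike build_gather, where W is chosen at random, the index width here is taken from the
# update ('input') tensor so the indices tensor matches it.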
4910 indicies_arr = np.int32(
4911 self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
4912 ) # (N, W)
4913 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004914
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004915 result_tens = OutputShaper.scatterOp(
4916 self.ser, self.rng, values_in, indicies, input, error_name
4917 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08004918
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004919 # Invalidate Input/Output list for error if checks.
4920 input_list = [values_in.name, indicies.name, input.name]
4921 output_list = [result_tens.name]
4922 pCount, cCount = op["operands"]
4923 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004924 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4925 self, error_name, input_list, output_list
4926 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004927
Les Bell729b0352021-11-24 10:28:21 +00004928 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004929 self.ser,
4930 validator_fcns,
4931 error_name,
4932 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004933 input_shape=values_in.shape,
4934 output_shape=result_tens.shape,
4935 input_dtype=values_in.dtype,
4936 output_dtype=result_tens.dtype,
4937 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004938 input_list=input_list,
4939 output_list=output_list,
4940 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004941 ):
4942 return None
Kevin Cheng77d0f762020-11-24 10:26:32 -08004943
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004944 self.ser.addOperator(op["op"], input_list, output_list)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004945
Kevin Cheng77d0f762020-11-24 10:26:32 -08004946 return result_tens
4947
Kevin Cheng550ccc52021-03-03 11:21:43 -08004948 def build_resize(
4949 self,
4950 op,
4951 input,
4952 mode,
4953 stride,
4954 offset,
4955 shift,
4956 stride_fp,
4957 offset_fp,
4958 output_dims,
4959 input_dtype,
4960 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01004961 validator_fcns,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004962 error_name=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004963 ):
4964 result_tens = OutputShaper.resizeOp(
4965 self.ser,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004966 self.rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004967 input,
4968 mode,
4969 stride,
4970 offset,
4971 shift,
4972 stride_fp,
4973 offset_fp,
4974 output_dims,
4975 input_dtype,
4976 output_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004977 error_name,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004978 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004979
Matthew Haddon848efb42021-09-09 12:30:53 +01004980 # Invalidate Input/Output list for error if checks.
4981 input_list = [input.name]
4982 output_list = [result_tens.name]
4983 pCount, cCount = op["operands"]
4984 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004985 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4986 self, error_name, input_list, output_list
4987 )
Matthew Haddone86fd342021-09-07 16:12:21 +01004988
Les Bell729b0352021-11-24 10:28:21 +00004989 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon848efb42021-09-09 12:30:53 +01004990 self.ser,
4991 validator_fcns,
4992 error_name,
4993 op=op,
4994 mode=mode,
4995 shift=shift,
4996 input_dtype=input_dtype,
4997 output_dtype=output_dtype,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004998 input_shape=input.shape,
Matthew Haddon848efb42021-09-09 12:30:53 +01004999 output_shape=output_dims,
5000 offset=offset,
5001 offset_fp=offset_fp,
5002 stride=stride,
5003 stride_fp=stride_fp,
5004 input_list=input_list,
5005 output_list=output_list,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01005006 result_tensor=result_tens,
Matthew Haddon848efb42021-09-09 12:30:53 +01005007 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00005008 ):
5009 return None
Matthew Haddone86fd342021-09-07 16:12:21 +01005010
Eric Kunzee5e26762020-10-13 16:11:07 -07005011 attr = ts.TosaSerializerAttribute()
Kevin Cheng77d0f762020-11-24 10:26:32 -08005012
Kevin Cheng550ccc52021-03-03 11:21:43 -08005013 attr.ResizeAttribute(
5014 output_dims, stride, offset, shift, stride_fp, offset_fp, mode
5015 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005016
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005017 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07005018 return result_tens
5019
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005020 def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
5021 result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
5022 result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005023 self.ser.addOperator(
5024 op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
5025 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005026 return result_tens
5027
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005028 def build_const(self, op, val, validator_fcns=None, error_name=None):
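# CONST tests simply mark the pre-built constant tensor as a graph output; no operator is
# serialized here.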
Kevin Cheng17e92022021-10-01 14:33:33 -07005029 self.ser.addOutputTensor(val)
5030 return val
Eric Kunzee5e26762020-10-13 16:11:07 -07005031
5032 # Type Conversion
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005033 def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005034 result_tens = OutputShaper.typeConversionOp(
5035 self.ser, self.rng, val, out_dtype, error_name
5036 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005037
5038 # Invalidate Input/Output list for error if checks.
5039 input_list = [val.name]
5040 output_list = [result_tens.name]
5041 pCount, cCount = op["operands"]
5042 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005043 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
5044 self, error_name, input_list, output_list
5045 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005046
Les Bell729b0352021-11-24 10:28:21 +00005047 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005048 self.ser,
5049 validator_fcns,
5050 error_name,
5051 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005052 input_shape=val.shape,
5053 output_shape=result_tens.shape,
5054 input_dtype=val.dtype,
5055 output_dtype=result_tens.dtype,
5056 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005057 input_list=input_list,
5058 output_list=output_list,
5059 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00005060 ):
5061 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005062
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005063 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07005064 return result_tens
5065
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005066 def build_rescale(
5067 self,
5068 op,
5069 val,
5070 out_dtype,
5071 scale32,
5072 double_round,
5073 per_channel,
5074 validator_fcns,
5075 error_name,
5076 ):
5077 result_tens = OutputShaper.typeConversionOp(
5078 self.ser, self.rng, val, out_dtype, error_name
5079 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005080
5081 if per_channel:
5082 nc = val.shape[-1]
5083 else:
5084 nc = 1
5085
5086 in_type_width = self.typeWidth(val.dtype)
5087 out_type_width = self.typeWidth(out_dtype)
5088
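# Zero-point handling: INT8/UINT8 tensors (and the ZeroPointNotZero error cases) get a
# randomly chosen zero point, and the effective type width is counted one bit wider for it.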
Kevin Cheng3a478572021-01-22 17:21:02 -08005089 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01005090 input_zp = self.randInt(-128, 128)
5091 in_type_width = in_type_width + 1
Kevin Chengacb550f2021-06-29 15:32:19 -07005092 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01005093 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07005094 in_type_width = in_type_width + 1
Matthew Haddonc2025212021-10-08 21:21:05 +01005095 elif error_name == ErrorIf.InputZeroPointNotZero:
5096 input_zp = self.randInt(-128, 128)
5097 if input_zp == 0:
5098 input_zp = input_zp + self.rng.integers(1, 10)
5099 in_type_width = in_type_width + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005100 else:
5101 input_zp = 0
5102
Kevin Cheng3a478572021-01-22 17:21:02 -08005103 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01005104 output_zp = self.randInt(-128, 128)
5105 out_type_width = out_type_width + 1
5106 elif out_dtype == DType.UINT8:
5107 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07005108 out_type_width = out_type_width + 1
Matthew Haddonc2025212021-10-08 21:21:05 +01005109 elif error_name == ErrorIf.OutputZeroPointNotZero:
5110 output_zp = self.randInt(-128, 128)
5111 if output_zp == 0:
5112 output_zp = output_zp + self.rng.integers(1, 10)
5113 out_type_width = out_type_width + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005114 else:
5115 output_zp = 0
5116
5117 # Calculate scale based on:
5118 # scale = a *(2^output_width)/(2^input_width))
5119
5120 a = np.float32(self.rng.random(size=[nc]))
5121 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
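# Illustrative example (not from the original source), assuming typeWidth() returns the
# nominal bit width: an INT8 input (8 + 1 bits once its zero point is counted) rescaled to
# an INT16 output (16 bits) gives scale_arr = a * 2**16 / 2**9 = 128 * a.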
5122
5123 if scale32:
Matthew Haddonb724efc2021-08-25 16:40:29 +01005125 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07005126 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
5127 else:
5128 # Cap the scaling at 2^15 - 1 for scale16
5129 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
5130
Kevin Cheng550ccc52021-03-03 11:21:43 -08005131 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07005132
5133 multiplier_arr = np.int32(np.zeros(shape=[nc]))
5134 shift_arr = np.int32(np.zeros(shape=[nc]))
5135
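        # Each floating point scale is decomposed into a fixed-point
        # (multiplier, shift) pair, i.e. multiplier / 2^shift approximates
        # scale_arr[i], which is the encoding the RESCALE attribute expects.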
5136 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005137 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
5138 scale_arr[i], scale32
5139 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005140
Kevin Cheng550ccc52021-03-03 11:21:43 -08005141 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07005142
Matthew Haddonc2025212021-10-08 21:21:05 +01005143 # Invalidate Input/Output list for error if checks.
5144 input_list = [val.name]
5145 output_list = [result_tens.name]
5146 pCount, cCount = op["operands"]
5147 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005148 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
5149 self, error_name, input_list, output_list
5150 )
Matthew Haddonc2025212021-10-08 21:21:05 +01005151
5152 qinfo = (input_zp, output_zp)
Les Bell729b0352021-11-24 10:28:21 +00005153 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc2025212021-10-08 21:21:05 +01005154 self.ser,
5155 validator_fcns,
5156 error_name,
5157 op=op,
5158 input_dtype=val.dtype,
5159 output_dtype=out_dtype,
5160 input_shape=val.shape,
5161 qinfo=qinfo,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005162 scale32=scale32,
5163 double_round=double_round,
Matthew Haddonc2025212021-10-08 21:21:05 +01005164 input_list=input_list,
5165 output_list=output_list,
5166 result_tensor=result_tens,
5167 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00005168 ):
5169 return None
Matthew Haddonc2025212021-10-08 21:21:05 +01005170
Eric Kunzee5e26762020-10-13 16:11:07 -07005171 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08005172 attr.RescaleAttribute(
5173 input_zp,
5174 output_zp,
5175 multiplier_arr,
5176 shift_arr,
5177 scale32,
5178 double_round,
5179 per_channel,
5180 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005181
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005182 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07005183 return result_tens
5184
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005185 def build_cond_if_const(
5186 self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
5187 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07005188 # For cond_if with constants, we're supplied with then/else tensors that we ignore
5189 # (except for the generated shape) and the condition. Build Then/Else blocks
5190 # and fill them with const nodes for the body.
5191
5192 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08005193 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07005194
5195 # Make then/else tensors
5196 out_shape = then_tens.shape
Matthew Haddon630c17c2021-10-14 15:05:41 +01005197
5198 # Create an incorrect output shape for error_if tests
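        # Each dimension is nudged by a non-zero amount (and kept positive)
        # so the resulting shape is guaranteed to differ from the correct one.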
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005199 if error_name in [
5200 ErrorIf.CondIfOutputListThenGraphMismatch,
5201 ErrorIf.CondIfOutputListElseGraphMismatch,
5202 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01005203 incorrect_shape = deepcopy(then_tens.shape)
5204 for i in range(len(incorrect_shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005205 incorrect_shape[i] += (
5206 self.rng.choice([-3, -2, 2, 3])
5207 if incorrect_shape[i] > 3
5208 else self.rng.choice([1, 2, 4])
5209 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01005210 incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))
5211
Jeremy Johnson18e26662021-07-22 16:15:29 +01005212 then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
5213 else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07005214
5215 # And the result tensor based on any of the outputs
Kevin Cheng550ccc52021-03-03 11:21:43 -08005216 result_tens = self.ser.addOutput(out_shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07005217
5218 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08005219 then_block = "THEN_BLOCK"
5220 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07005221 attr = ts.TosaSerializerAttribute()
5222 attr.CondIfAttribute(then_block, else_block)
5223
5224 # Finally, build the op and the two blocks
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005225 self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07005226
5227 self.ser.startBasicBlock(then_block)
5228 # Build the actual then/else tensors inside their blocks
Matthew Haddon630c17c2021-10-14 15:05:41 +01005229 if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
5230 then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
5231 else:
5232 then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07005233 self.ser.addOutputTensor(then_tens)
5234
5235 self.ser.startBasicBlock(else_block)
Matthew Haddon630c17c2021-10-14 15:05:41 +01005236 if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
5237 else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
5238 else:
5239 else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07005240 self.ser.addOutputTensor(else_tens)
5241
Les Bell729b0352021-11-24 10:28:21 +00005242 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01005243 self.ser,
5244 validator_fcns,
5245 error_name,
5246 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005247 basicBlocks=self.ser.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00005248 ):
5249 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01005250
Eric Kunzee5e26762020-10-13 16:11:07 -07005251 return result_tens
5252
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005253 def build_cond_if_binary(
5254 self, op, a, b, cond, validator_fcns=None, error_name=None
5255 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07005256 # For cond_if with a binary op in the then/else blocks, take a and b and
5257 # alternately add or subtract them based on the condition
5258
5259 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08005260 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07005261
Kevin Cheng550ccc52021-03-03 11:21:43 -08005262 result_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005263
5264 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08005265 then_block = "THEN_BLOCK"
5266 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07005267 attr = ts.TosaSerializerAttribute()
5268 attr.CondIfAttribute(then_block, else_block)
5269
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005270 if error_name in [
5271 ErrorIf.CondIfInputListThenGraphMismatch,
5272 ErrorIf.CondIfInputListElseGraphMismatch,
5273 ErrorIf.CondIfOutputListElseGraphMismatch,
5274 ErrorIf.CondIfOutputListThenGraphMismatch,
5275 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01005276 incorrect_shape = a.shape.copy()
5277 for i in range(len(incorrect_shape)):
5278 incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
5279 incorrect_block_input = deepcopy(a)
5280 incorrect_block_input.shape = incorrect_shape
5281
Eric Kunzee5e26762020-10-13 16:11:07 -07005282 # Finally, build the op and the two blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08005283 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005284 op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
Kevin Cheng550ccc52021-03-03 11:21:43 -08005285 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005286
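        # Pick block ops that are legal for the tensor dtype: ADD/SUB only
        # support INT32/FLOAT, so the narrower integer types use the
        # logical shift ops instead.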
Les Bell6040b4d2021-10-11 12:50:31 +01005287 if a.dtype in (DType.FLOAT, DType.INT32):
5288 then_op, else_op = Op.ADD, Op.SUB
5289 elif a.dtype in (DType.INT8, DType.INT16):
5290 then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
5291 else:
5292 assert False, f"No tests for DType: {a.dtype}"
Eric Kunzee5e26762020-10-13 16:11:07 -07005293
Les Bell6040b4d2021-10-11 12:50:31 +01005294 for block, op in ((then_block, then_op), (else_block, else_op)):
5295 self.ser.startBasicBlock(block)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005296 if (
5297 error_name == ErrorIf.CondIfInputListThenGraphMismatch
5298 and block == then_block
5299 ) or (
5300 error_name == ErrorIf.CondIfInputListElseGraphMismatch
5301 and block == else_block
5302 ):
Matthew Haddon630c17c2021-10-14 15:05:41 +01005303 self.ser.addInputTensor(incorrect_block_input)
5304 self.ser.addInputTensor(b)
5305 tens = self.ser.addOutput(a.shape, a.dtype)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005306 elif (
5307 error_name == ErrorIf.CondIfOutputListThenGraphMismatch
5308 and block == then_block
5309 ) or (
5310 error_name == ErrorIf.CondIfOutputListElseGraphMismatch
5311 and block == else_block
5312 ):
Matthew Haddon630c17c2021-10-14 15:05:41 +01005313 self.ser.addInputTensor(a)
5314 self.ser.addInputTensor(b)
5315 tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
5316 else:
5317 self.ser.addInputTensor(a)
5318 self.ser.addInputTensor(b)
5319 tens = self.ser.addOutput(a.shape, a.dtype)
Les Bell6040b4d2021-10-11 12:50:31 +01005320 self.ser.addOperator(op, [a.name, b.name], [tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07005321
Les Bell729b0352021-11-24 10:28:21 +00005322 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01005323 self.ser,
5324 validator_fcns,
5325 error_name,
5326 op=op,
5327 a=a,
5328 b=b,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005329 basicBlocks=self.ser.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00005330 ):
5331 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01005332
Eric Kunzee5e26762020-10-13 16:11:07 -07005333 return result_tens
5334
Matthew Haddon630c17c2021-10-14 15:05:41 +01005335 def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
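        # The generated graph decrements iter on each pass and adds a into acc,
        # so the loop runs iter_val times and acc accumulates a element-wise.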
Kevin Cheng550ccc52021-03-03 11:21:43 -08005336 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07005337
Kevin Cheng550ccc52021-03-03 11:21:43 -08005338 cond_block = "COND_BLOCK"
5339 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07005340
5341 attr = ts.TosaSerializerAttribute()
5342 attr.WhileLoopAttribute(cond_block, body_block)
5343
5344 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08005345 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005346 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08005347 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07005348
5349 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08005350 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
5351 a_out = self.ser.addIntermediate(a.shape, a.dtype)
Matthew Haddon630c17c2021-10-14 15:05:41 +01005352 if error_name == ErrorIf.InputListOutputListMismatch:
5353 incorrect_acc = deepcopy(acc)
5354 for i in range(len(incorrect_acc.shape)):
5355 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
5356 acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
5357 else:
5358 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005359
5360 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08005361 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005362 op["op"],
Kevin Cheng550ccc52021-03-03 11:21:43 -08005363 [iter.name, a.name, acc.name],
5364 [iter_out.name, a_out.name, acc_out.name],
5365 attr,
5366 )
Kevin Chengb227ae52021-09-02 13:43:17 -07005367 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07005368
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005369 if error_name in [
5370 ErrorIf.InputListCondGraphMismatch,
5371 ErrorIf.InputListBodyGraphInputMismatch,
5372 ErrorIf.InputListBodyGraphOutputMismatch,
5373 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01005374 incorrect_iter = deepcopy(iter)
5375 for i in range(len(incorrect_iter.shape)):
5376 incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
5377 if len(incorrect_iter.shape) == 0:
5378 incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
5379
5380 incorrect_acc = deepcopy(acc)
5381 for i in range(len(incorrect_acc.shape)):
5382 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
5383
Eric Kunzee5e26762020-10-13 16:11:07 -07005384 # COND block (input: iter, output: cond_tens )
5385 self.ser.startBasicBlock(cond_block)
Matthew Haddon630c17c2021-10-14 15:05:41 +01005386 if error_name == ErrorIf.InputListCondGraphMismatch:
5387 self.ser.addInputTensor(incorrect_iter)
5388 self.ser.addInputTensor(a)
5389 self.ser.addInputTensor(incorrect_acc)
5390 else:
5391 self.ser.addInputTensor(iter)
5392 self.ser.addInputTensor(a)
5393 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005394 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01005395
5396 if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005397 cond_tens = self.ser.addOutput(
5398 [], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT])
5399 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01005400 else:
5401 cond_tens = self.ser.addOutput([], DType.BOOL)
5402
Kevin Cheng550ccc52021-03-03 11:21:43 -08005403 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07005404
5405 # BODY block (input: a, acc, iter, output: a, acc, iter)
5406 # Note that local intermediate tensors need to be declared here for the outputs
5407 self.ser.startBasicBlock(body_block)
Matthew Haddon630c17c2021-10-14 15:05:41 +01005408 if error_name == ErrorIf.InputListBodyGraphInputMismatch:
5409 self.ser.addInputTensor(incorrect_iter)
5410 self.ser.addInputTensor(a)
5411 self.ser.addInputTensor(incorrect_acc)
5412 else:
5413 self.ser.addInputTensor(iter)
5414 self.ser.addInputTensor(a)
5415 self.ser.addInputTensor(acc)
5416
Kevin Cheng550ccc52021-03-03 11:21:43 -08005417 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01005418
5419 if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005420 iter_body_out = self.ser.addIntermediate(
5421 incorrect_iter.shape, incorrect_iter.dtype
5422 )
5423 acc_body_out = self.ser.addIntermediate(
5424 incorrect_acc.shape, incorrect_acc.dtype
5425 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01005426 else:
5427 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
5428 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
5429
Eric Kunzee5e26762020-10-13 16:11:07 -07005430 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
5431 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
5432 self.ser.addOutputTensor(iter_body_out)
5433 self.ser.addOutputTensor(a)
5434 self.ser.addOutputTensor(acc_body_out)
5435
Les Bell729b0352021-11-24 10:28:21 +00005436 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01005437 self.ser,
5438 validator_fcns,
5439 error_name,
5440 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005441 basicBlocks=self.ser.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00005442 ):
5443 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01005444
Eric Kunzee5e26762020-10-13 16:11:07 -07005445 return acc_out
5446
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005447 def create_filter_lists(
5448 self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
5449 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01005450 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
5451 default_test_rank_range = range(1, 5)
5452 if not shapeFilter:
5453 shapeFilter = [None]
5454
5455 # Calculate the filters based on what is requested and what the operator allows
5456 rmin, rmax = op["rank"]
5457 if rankFilter is not None:
5458 cleanRankFilter = []
5459 # Ensure rankFilter values are allowed by operator
5460 for rank in rankFilter:
5461 if rank >= rmin and rank <= rmax:
5462 cleanRankFilter.append(rank)
5463 elif rankFilter is None and shapeFilter[0] is None:
Jeremy Johnson03bec732021-10-07 12:06:00 +01005464 # Ensure default behaviour is bounded by default range or by operator,
5465 # whichever is the smaller range of ranks.
5466 opRankRange = range(rmin, rmax + 1)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005467 cleanRankFilter = (
5468 opRankRange
5469 if len(opRankRange) <= len(default_test_rank_range)
5470 else default_test_rank_range
5471 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01005472 else:
5473 cleanRankFilter = range(rmin, rmax + 1)
5474
5475 dtypes = op["types"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005476
Matthew Haddon1c00b712021-10-01 15:51:03 +01005477 if dtypeFilter is not None:
5478 cleanDtypeFilter = []
Jeremy Johnson03bec732021-10-07 12:06:00 +01005479 # Create list of operator dtypes filtered by requested dtypes
5480 for dtype in dtypes:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005481 if dtype in dtypeFilter or (
5482 isinstance(dtype, list) and dtype[0] in dtypeFilter
5483 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01005484 cleanDtypeFilter.append(dtype)
5485 else:
5486 cleanDtypeFilter = dtypes
5487
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005488 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01005489 filterDict = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005490 "shapeFilter": shapeFilter,
5491 "rankFilter": cleanRankFilter,
5492 "dtypeFilter": cleanDtypeFilter,
Matthew Haddon1c00b712021-10-01 15:51:03 +01005493 }
5494 return filterDict
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005495 elif testType == "negative":
Matthew Haddone807aae2021-10-11 18:12:58 +01005496 if validator is not None:
5497 validator_info = validator(check=False, op=op)
5498 else:
5499 return None
5500
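            # param_reqs describes the ranks/dtypes/shapes this negative test
            # must use in order to trigger the requested ERROR_IF case.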
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005501 error_arguments = validator_info["param_reqs"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01005502
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005503 # Set parameters as required
5504 if error_arguments["rank"] is not None:
5505 rankFilter = error_arguments["rank"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01005506 else:
5507 rankFilter = cleanRankFilter
5508
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005509 if error_arguments["dtype"] is not None:
5510 dtypeFilter = error_arguments["dtype"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01005511 else:
5512 dtypeFilter = cleanDtypeFilter
5513
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005514 if error_arguments["shape"] is not None:
5515 shapeFilter = error_arguments["shape"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01005516 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005517 shapeFilter = shapeFilter[
5518 :2
5519 ] # Reduce number of shapes to keep test numbers small
Matthew Haddon1c00b712021-10-01 15:51:03 +01005520
5521 filterDict = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005522 "shapeFilter": shapeFilter,
5523 "rankFilter": rankFilter,
5524 "dtypeFilter": dtypeFilter,
Matthew Haddon1c00b712021-10-01 15:51:03 +01005525 }
5526 return filterDict
5527
Kevin Cheng550ccc52021-03-03 11:21:43 -08005528 def genOpTestList(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005529 self,
5530 opName,
5531 shapeFilter=[None],
5532 rankFilter=None,
5533 dtypeFilter=None,
5534 testType="positive",
Kevin Cheng550ccc52021-03-03 11:21:43 -08005535 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07005536
5537 try:
5538 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005539 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005540 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07005541
5542 # Initialize a new random number generator
5543 self.rng = np.random.default_rng(self.random_seed)
5544
Kevin Cheng550ccc52021-03-03 11:21:43 -08005545 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07005546
Eric Kunzee5e26762020-10-13 16:11:07 -07005547 # Test list consists of tuples of:
5548 # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
5549 testList = []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005550 if testType == "negative" and "error_if_validators" in op:
Matthew Haddon1c00b712021-10-01 15:51:03 +01005551 error_if_validators = op["error_if_validators"]
5552 else:
5553 error_if_validators = [None]
Eric Kunzee5e26762020-10-13 16:11:07 -07005554
Matthew Haddon1c00b712021-10-01 15:51:03 +01005555 for validator in error_if_validators:
5556 if validator is not None:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005557 error_name = validator(check=False, op=op)["error_name"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01005558 else:
5559 error_name = None
5560
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005561 filterDict = self.create_filter_lists(
5562 op, shapeFilter, rankFilter, dtypeFilter, testType, validator
5563 )
5564 if filterDict is None:
Matthew Haddone807aae2021-10-11 18:12:58 +01005565 return []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005566 cleanRankFilter = filterDict["rankFilter"]
5567 cleanDtypeFilter = filterDict["dtypeFilter"]
5568 cleanShapeFilter = filterDict["shapeFilter"]
5569 # print(f"Error: {error_name}, Filters: S {cleanShapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01005570
5571 for r in cleanRankFilter:
Matthew Haddon1c00b712021-10-01 15:51:03 +01005572 for t in cleanDtypeFilter:
5573 for shape in cleanShapeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01005574 # Filter out by rank
5575 if shape is not None and len(shape) != r:
5576 continue
Matthew Haddon74567092021-07-16 15:38:20 +01005577 self.setTargetShape(shape)
Matthew Haddon1c00b712021-10-01 15:51:03 +01005578 shapeList = tgen_fcn(self, op, r, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07005579
Matthew Haddon74567092021-07-16 15:38:20 +01005580 shapeStr = self.shapeStr(shapeList[0])
5581 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07005582
Matthew Haddon74567092021-07-16 15:38:20 +01005583 # Argument list consists of tuples of the (str, []) string representation and the build function argument list
5584 argList = []
5585 if agen_fcn:
Matthew Haddon1c00b712021-10-01 15:51:03 +01005586 argList = agen_fcn(self, opName, shapeList, t, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07005587 else:
Matthew Haddon74567092021-07-16 15:38:20 +01005588 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07005589
Matthew Haddon74567092021-07-16 15:38:20 +01005590 for argStr, args in argList:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005591 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01005592 if argStr:
5593 testStr = "{}_{}_{}_{}".format(
5594 opName, shapeStr, typeStr, argStr
5595 )
5596 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005597 testStr = "{}_{}_{}".format(
5598 opName, shapeStr, typeStr
5599 )
5600 elif testType == "negative":
Matthew Haddone86fd342021-09-07 16:12:21 +01005601 if argStr:
5602 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
5603 opName, error_name, shapeStr, typeStr, argStr
5604 )
5605 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005606 testStr = "{}_ERRORIF_{}_{}_{}".format(
5607 opName, error_name, shapeStr, typeStr
5608 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01005609
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005610 testList.append(
5611 (opName, testStr, t, error_name, shapeList, args)
5612 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01005613
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005614 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01005615 # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
5616 if "invalid_test_validators" in op:
5617 invalid_test_validators = op["invalid_test_validators"]
5618 clean_testList = []
5619 for test in testList:
5620 for validator_fcn in invalid_test_validators:
5621 remove_test = False
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005622 if validator_fcn(
5623 opName=test[0],
5624 input_dtype=test[2],
5625 shapeList=test[4],
5626 args=test[5],
5627 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01005628 remove_test = True
5629 if not remove_test:
5630 clean_testList.append(test)
5631 testList = clean_testList
Eric Kunzee5e26762020-10-13 16:11:07 -07005632
5633 return testList
5634
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005635 def serializeTest(
5636 self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
5637 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07005638 try:
5639 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005640 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005641 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07005642
5643 # Create a serializer
5644 self.createSerializer(opName, testStr)
5645
Kevin Cheng550ccc52021-03-03 11:21:43 -08005646 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01005647 if "error_if_validators" in op:
5648 error_if_validators = op["error_if_validators"]
5649 else:
5650 error_if_validators = None
5651
Kevin Cheng550ccc52021-03-03 11:21:43 -08005652 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07005653 num_operands = pCount + cCount
5654
5655 if isinstance(dtype_or_dtypeList, list):
5656 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07005657 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01005658 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07005659 else:
5660 dtypeList = [dtype_or_dtypeList] * (num_operands)
5661
Kevin Cheng93a16282021-08-31 16:14:03 -07005662 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01005663 assert (
5664 len(shapeList) == num_operands
5665 ), "shapeList length {} must match number of operands {}".format(
5666 len(shapeList), num_operands
5667 )
5668 assert (
5669 len(dtypeList) == num_operands
5670 ), "dtypeList length {} must match number of operands {}".format(
5671 len(dtypeList), num_operands
5672 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005673
5674 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005675 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07005676 except KeyError:
5677 qgen = None
5678
5679 # Build the random tensor operands and the test
5680 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08005681
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005682 tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01005683
5684 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005685 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01005686 else:
5687 qinfo = None
5688
5689 try:
5690 if error_if_validators is None:
5691 if qinfo is not None:
5692 resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
5693 else:
5694 resultName = build_fcn(self, op, *tens, *testArgs)
5695 else:
5696 if qinfo is not None:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005697 resultName = build_fcn(
5698 self,
5699 op,
5700 *tens,
5701 *testArgs,
5702 validator_fcns=error_if_validators,
5703 error_name=error_name,
5704 qinfo=qinfo,
5705 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01005706 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005707 resultName = build_fcn(
5708 self,
5709 op,
5710 *tens,
5711 *testArgs,
5712 validator_fcns=error_if_validators,
5713 error_name=error_name,
5714 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01005715 except TypeError as e:
Les Bell0e027d42021-11-09 14:42:14 +00005716 print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
Matthew Haddon1c00b712021-10-01 15:51:03 +01005717 raise e
5718
Les Bell729b0352021-11-24 10:28:21 +00005719 if resultName:
5720 # The test is valid, serialize it
5721 self.serialize("test")
5722 else:
5723 # The test is not valid
5724 print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01005725
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005726 def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
Matthew Haddon1c00b712021-10-01 15:51:03 +01005727 pCount, cCount = op["operands"]
5728
5729 tens = []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005730 if (
5731 (op["op"] == Op.ADD or op["op"] == Op.SUB)
5732 and dtypeList[0] == DType.INT32
5733 and error_name is None
5734 ):
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005735 # Make sure the operation does not cause value saturation - where
5736 # the number wraps due to limited number of bits to store the answer
5737 assert (
5738 pCount == 2 and cCount == 0
5739 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005740 placeholders = []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005741 add = op["op"] == Op.ADD
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005742 a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
5743 b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
5744 if add:
5745 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
5746 else:
5747 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
5748
5749 # Work out the saturation limits
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005750 max_i32 = (1 << 31) - 1
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005751 min_i32 = -(1 << 31)
5752 max_arr = np.full(shapeList[1], max_i32)
5753 min_arr = np.full(shapeList[1], min_i32)
5754
5755 # Find how much values exceed the maximum/minimums
5756 sat_max_arr = np.maximum(res_arr - max_arr, 0)
5757 sat_min_arr = np.minimum(res_arr - min_arr, 0)
5758
5759 if not add:
5760 # Swap saturation values and negate values as we need to perform opposite operations
5761 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
5762
5763 # Create new array of unsaturated values by clipping values as needed
5764 b_unsat_arr = b_arr
5765 if (sat_max_arr != 0).any():
5766 # Clip values that cause saturation
5767 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
5768 # Reduce axes in unsaturated tensor to match original tensor
5769 for axis, dim in enumerate(b_arr.shape):
5770 if dim != b_unsat_arr.shape[axis]:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005771 assert (
5772 dim == 1
5773 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005774 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
5775
5776 if (sat_min_arr != 0).any():
5777 # Clip values that cause saturation
5778 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
5779 # Reduce axes in unsaturated tensor to match original tensor
5780 for axis, dim in enumerate(b_arr.shape):
5781 if dim != b_unsat_arr.shape[axis]:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005782 assert (
5783 dim == 1
5784 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005785 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
5786
5787 placeholders.append(
5788 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
5789 )
5790 placeholders.append(
5791 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
5792 )
5793
5794 tens.extend(placeholders)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005795 elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[
5796 0
5797 ] == DType.INT32:
Jeremy Johnson8c06a652021-10-20 15:51:11 +01005798 # Limit input tensors with cond_if_binary or while_loop to stop
5799 # saturation of add/sub ops
5800 pRemain = pCount
5801 placeholders = []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005802 for idx, shape in enumerate(shapeList[:]):
Jeremy Johnson8c06a652021-10-20 15:51:11 +01005803 arr = self.getRandTensor(shapeList[idx], DType.INT16)
5804 if pRemain > 0:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005805 placeholders.append(
5806 self.ser.addPlaceholder(shape, dtypeList[idx], arr)
5807 )
Jeremy Johnson8c06a652021-10-20 15:51:11 +01005808 pRemain -= 1
5809 else:
5810 placeholders.append(self.ser.addConst(shape, dtypeList[idx], arr))
5811
5812 tens.extend(placeholders)
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005813 elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
5814 # Force value of operand[1] to be within [0, num_bits]
Kevin Cheng550ccc52021-03-03 11:21:43 -08005815 assert (
5816 pCount == 2 and cCount == 0
5817 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
Kevin Chengaee1fac2020-11-11 13:54:06 -08005818
5819 placeholders = []
5820 for idx, shape in enumerate(shapeList[:]):
5821 if idx == 1:
Kevin Cheng989cb052021-04-28 16:29:44 -07005822 if dtypeList[idx] == DType.INT8:
Kevin Chengaee1fac2020-11-11 13:54:06 -08005823 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07005824 elif dtypeList[idx] == DType.INT16:
Kevin Chengaee1fac2020-11-11 13:54:06 -08005825 arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07005826 elif dtypeList[idx] == DType.INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -08005827 arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005828 elif error_name == ErrorIf.WrongInputType:
5829 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Chengaee1fac2020-11-11 13:54:06 -08005830 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005831 raise Exception("OpArithmeticRightShift: invalid input dtype")
Kevin Chengaee1fac2020-11-11 13:54:06 -08005832 else:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005833 arr = self.getRandTensor(shape, dtypeList[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -07005834 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
Kevin Chengaee1fac2020-11-11 13:54:06 -08005835
5836 tens.extend(placeholders)
Matthew Haddona44ac5e2021-07-27 16:31:16 +01005837 elif op["op"] == Op.SELECT:
5838 # Set datatype of condition tensor to boolean
5839 dtypeList[0] = DType.BOOL
5840 tens.extend(
5841 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
5842 )
5843 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005844 elif op["op"] == Op.INTDIV and error_name is None:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005845 assert (
5846 pCount == 2 and cCount == 0
Matthew Haddon459443c2021-08-23 16:43:13 +01005847 ), "Op.INTDIV must have 2 placeholders, 0 consts"
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005848
5849 placeholders = []
5850
Matthew Haddon459443c2021-08-23 16:43:13 +01005851 # Two invalid cases for Op.INTDIV:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005852 # 1. divisor == 0
Kevin Cheng47315e12021-05-13 17:41:28 -07005853 # 2. dividend == -(1<<31) and divisor == -1
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005854 while True:
5855 dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
5856 divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])
5857
5858 if (divisor_arr == 0).any():
5859 continue
5860
Kevin Cheng47315e12021-05-13 17:41:28 -07005861 if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005862 continue
5863
5864 break
5865
5866 placeholders.append(
5867 self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
5868 )
5869 placeholders.append(
5870 self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
5871 )
5872
5873 tens.extend(placeholders)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005874 elif op["op"] == Op.MUL and error_name is None:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005875 assert (
5876 pCount == 2 and cCount == 0
5877 ), "Op.MUL must have 2 placeholders, 0 consts"
5878
5879 if dtypeList[0] == DType.FLOAT:
5880 tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
5881 else:
5882 placeholders = []
5883
5884 # Make sure the multiply result fits in the int32 range
5885 shift = testArgs[0]
5886 if dtypeList[0] == DType.INT8:
5887 num_bits = 8
5888 elif dtypeList[0] == DType.INT16:
5889 num_bits = 16
5890 elif dtypeList[0] == DType.INT32:
5891 num_bits = 32
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005892 elif error_name == ErrorIf.WrongInputType:
5893 num_bits = 8
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005894 else:
5895 raise Exception("OpMul: invalid input dtype")
5896
5897 for idx, shape in enumerate(shapeList[:]):
5898 low = -(2 ** (num_bits - 1))
5899 high = (2 ** (num_bits - 1)) - 1
5900
5901 a_arr = np.int32(
5902 self.rng.integers(low=low, high=high, size=shapeList[0])
5903 )
5904 b_arr = np.int32(
5905 self.rng.integers(low=low, high=high, size=shapeList[1])
5906 )
5907
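                    # Halve both operands until the (rounded, shifted)
                    # product fits within the int32 range.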
5908 i = 0
5909 while True:
5910
5911 a_arr_64 = a_arr.astype(np.int64)
5912 b_arr_64 = b_arr.astype(np.int64)
5913
5914 if shift > 0:
5915 rounding = 1 << (shift - 1)
5916 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
5917 else:
5918 result_arr = a_arr_64 * b_arr_64
5919
5920 if (result_arr > -(2 ** 31)).all() and (
5921 result_arr <= ((2 ** 31) - 1)
5922 ).all():
5923 break
5924
5925 i = i + 1
5926 a_arr = a_arr // 2
5927 b_arr = b_arr // 2
5928
5929 placeholders.append(
5930 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
5931 )
5932 placeholders.append(
5933 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
5934 )
5935
5936 tens.extend(placeholders)
Matthew Haddon818ab902021-07-27 09:12:49 +01005937 elif op["op"] == Op.CONCAT:
5938 count = len(shapeList) - self.args.num_const_inputs_concat
5939 if count < 1:
5940 count = 1
5941 if self.args.num_const_inputs_concat == 0:
5942 count = len(shapeList)
5943
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005944 # Ensure axis is an int
5945 testArgs[0] = int(testArgs[0])
5946
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005947 shapeList = TosaTensorGen.tgConcatConstInput(
5948 self, shapeList, testArgs[0], error_name
5949 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005950
Matthew Haddon818ab902021-07-27 09:12:49 +01005951 tens.extend(
5952 self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
5953 )
5954 tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08005955 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07005956 tens.extend(
5957 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
5958 )
5959 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Eric Kunzee5e26762020-10-13 16:11:07 -07005960
Matthew Haddon1c00b712021-10-01 15:51:03 +01005961 return tens
Eric Kunzee5e26762020-10-13 16:11:07 -07005962
5963 def createDynamicOpLists(self):
5964
5965 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng1533b852021-09-01 12:51:58 -07005966 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005967
Kevin Cheng1533b852021-09-01 12:51:58 -07005968 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005969 testName = "conv2d_{}x{}".format(k[0], k[1])
5970 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
5971 self.TOSA_OP_LIST[testName]["filter"] = k
5972 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07005973
Kevin Cheng550ccc52021-03-03 11:21:43 -08005974 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
5975 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
5976 "depthwise_conv2d_TEMPLATE"
5977 ].copy()
5978 self.TOSA_OP_LIST[testName]["filter"] = k
5979 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07005980
Kevin Cheng550ccc52021-03-03 11:21:43 -08005981 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
5982 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
5983 "transpose_conv2d_TEMPLATE"
5984 ].copy()
5985 self.TOSA_OP_LIST[testName]["filter"] = k
5986 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07005987
Kevin Cheng1533b852021-09-01 12:51:58 -07005988 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
5989 for k in KERNELS_3D:
5990 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
5991 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
5992 self.TOSA_OP_LIST[testName]["filter"] = k
5993 self.TOSA_OP_LIST[testName]["template"] = False
5994
Eric Kunzee5e26762020-10-13 16:11:07 -07005995 # Delete any templates after having created any dynamic ops
5996 # This is a two-pass operation because it's bad practice to delete
5997 # keys from dictionaries while iterating
5998 keyList = []
5999 for k in self.TOSA_OP_LIST:
6000 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006001 if self.TOSA_OP_LIST[k]["template"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07006002 keyList.append(k)
6003 continue
6004 except KeyError:
6005 pass
6006
6007 for k in keyList:
6008 del self.TOSA_OP_LIST[k]
6009
6010 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08006011 """Fill in default fields for ops if they aren't already specified.
6012 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07006013 for op in self.TOSA_OP_LIST:
6014
6015 # Required fields
6016 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08006017 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07006018 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08006019 raise Exception(
6020 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
6021 )
Eric Kunzee5e26762020-10-13 16:11:07 -07006022
6023 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08006024 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07006025 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08006026 raise Exception(
6027 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
6028 op
6029 )
6030 )
Eric Kunzee5e26762020-10-13 16:11:07 -07006031
6032 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006033 _ = self.TOSA_OP_LIST[op]["types"]
6034 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08006035 raise Exception(
6036 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
6037 )
Eric Kunzee5e26762020-10-13 16:11:07 -07006038
6039 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006040 _ = self.TOSA_OP_LIST[op]["op"]
6041 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08006042 raise Exception(
6043 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
6044 )
Eric Kunzee5e26762020-10-13 16:11:07 -07006045
6046 # Put in default rank range, if missing
6047 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006048 _ = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07006049 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08006050 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07006051
6052 # Tensor operator list
6053 # 'op': op name
6054 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08006055 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
6056 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07006057 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
6058 # 'types': array of datatypes to be tested
Kevin Cheng550ccc52021-03-03 11:21:43 -08006059 TYPE_FP = [DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07006060
Kevin Cheng550ccc52021-03-03 11:21:43 -08006061 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
6062 TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07006063
Kevin Cheng550ccc52021-03-03 11:21:43 -08006064 TYPE_BOOL = [DType.BOOL]
6065 TYPE_FI32 = [DType.FLOAT, DType.INT32]
6066 TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
6067 TYPE_FI16 = [DType.FLOAT, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07006068
Kevin Cheng550ccc52021-03-03 11:21:43 -08006069 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07006070
Kevin Cheng1533b852021-09-01 12:51:58 -07006071 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07006072 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07006073 [DType.INT8, DType.INT8, DType.INT32],
6074 [DType.INT16, DType.INT8, DType.INT48],
6075 DType.FLOAT,
6076 ]
6077
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01006078 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07006079
6080 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08006081 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08006082 "argmax": {
6083 "op": Op.ARGMAX,
6084 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006085 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006086 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6087 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006088 "error_if_validators": (
6089 TosaErrorValidator.evAxisSmallerZero,
6090 TosaErrorValidator.evAxisLargerRank,
6091 TosaErrorValidator.evArgmaxOutputRankMismatch,
6092 TosaErrorValidator.evArgmaxOutputShapeMismatch,
6093 TosaErrorValidator.evWrongRank,
6094 TosaErrorValidator.evWrongInputType,
6095 TosaErrorValidator.evWrongOutputType,
6096 TosaErrorValidator.evWrongInputList,
6097 TosaErrorValidator.evWrongOutputList,
6098 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006099 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006100 "avg_pool2d": {
6101 "op": Op.AVG_POOL2D,
6102 "operands": (1, 0),
6103 "rank": (4, 4),
6104 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
6105 "qgen": TosaQuantGen.qgUnary,
6106 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00006107 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006108 "error_if_validators": (
6109 TosaErrorValidator.evKernelSmallerOne,
6110 TosaErrorValidator.evStrideSmallerOne,
6111 TosaErrorValidator.evPadSmallerZero,
6112 TosaErrorValidator.evWrongRank,
6113 TosaErrorValidator.evWrongInputType,
6114 TosaErrorValidator.evWrongOutputType,
6115 TosaErrorValidator.evWrongInputList,
6116 TosaErrorValidator.evWrongOutputList,
6117 TosaErrorValidator.evInputZeroPointNotZero,
6118 TosaErrorValidator.evOutputZeroPointNotZero,
6119 TosaErrorValidator.evPadLargerEqualKernel,
6120 TosaErrorValidator.evPoolingOutputShapeMismatch,
6121 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006122 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006123 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08006124 "conv2d_TEMPLATE": {
6125 "op": Op.CONV2D,
6126 "operands": (1, 2),
6127 "rank": (4, 4),
Les Bell7aa69f42021-09-20 10:44:07 +01006128 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006129 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07006130 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00006131 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
6132 "error_if_validators": (
6133 TosaErrorValidator.evWrongInputType,
6134 TosaErrorValidator.evWrongOutputType,
6135 TosaErrorValidator.evWrongInputList,
6136 TosaErrorValidator.evWrongOutputList,
6137 TosaErrorValidator.evInputZeroPointNotZero,
6138 TosaErrorValidator.evWeightZeroPointNotZero,
6139 TosaErrorValidator.evPadSmallerZero,
6140 TosaErrorValidator.evStrideSmallerOne,
6141 TosaErrorValidator.evDilationSmallerOne,
6142 TosaErrorValidator.evWrongRank,
6143 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006144 "template": True,
6145 },
Kevin Cheng1533b852021-09-01 12:51:58 -07006146 # Templated operator. Filled in by createDynamicOpLists
6147 "conv3d_TEMPLATE": {
6148 "op": Op.CONV3D,
6149 "operands": (1, 2),
6150 "rank": (5, 5),
Les Bell7aa69f42021-09-20 10:44:07 +01006151 "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
Kevin Cheng1533b852021-09-01 12:51:58 -07006152 "qgen": TosaQuantGen.qgConv,
6153 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00006154 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
6155 "error_if_validators": (
6156 TosaErrorValidator.evWrongInputType,
6157 TosaErrorValidator.evWrongOutputType,
6158 TosaErrorValidator.evWrongInputList,
6159 TosaErrorValidator.evWrongOutputList,
6160 TosaErrorValidator.evInputZeroPointNotZero,
6161 TosaErrorValidator.evWeightZeroPointNotZero,
6162 TosaErrorValidator.evPadSmallerZero,
6163 TosaErrorValidator.evStrideSmallerOne,
6164 TosaErrorValidator.evDilationSmallerOne,
6165 TosaErrorValidator.evWrongRank,
6166 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07006167 "template": True,
6168 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006169 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08006170 "depthwise_conv2d_TEMPLATE": {
6171 "op": Op.DEPTHWISE_CONV2D,
6172 "operands": (1, 2),
6173 "filter": [1, 1],
6174 "rank": (4, 4),
6175 "build_fcn": (
6176 build_depthwise_conv2d,
6177 TosaTensorGen.tgDepthwiseConv2D,
Les Bell7aa69f42021-09-20 10:44:07 +01006178 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08006179 ),
6180 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07006181 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00006182 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
6183 "error_if_validators": (
6184 TosaErrorValidator.evWrongInputType,
6185 TosaErrorValidator.evWrongOutputType,
6186 TosaErrorValidator.evWrongInputList,
6187 TosaErrorValidator.evWrongOutputList,
6188 TosaErrorValidator.evInputZeroPointNotZero,
6189 TosaErrorValidator.evWeightZeroPointNotZero,
6190 TosaErrorValidator.evPadSmallerZero,
6191 TosaErrorValidator.evStrideSmallerOne,
6192 TosaErrorValidator.evDilationSmallerOne,
6193 TosaErrorValidator.evWrongRank,
6194 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006195 "template": True,
6196 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006197 "fully_connected": {
6198 "op": Op.FULLY_CONNECTED,
6199 "operands": (1, 2),
6200 "rank": (2, 2),
6201 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
6202 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07006203 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006204 "error_if_validators": (
6205 TosaErrorValidator.evInputZeroPointNotZero,
6206 TosaErrorValidator.evWeightZeroPointNotZero,
6207 TosaErrorValidator.evWrongRank,
6208 TosaErrorValidator.evWrongInputType,
6209 TosaErrorValidator.evWrongOutputType,
6210 TosaErrorValidator.evWrongInputList,
6211 TosaErrorValidator.evWrongOutputList,
6212 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006213 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006214 "matmul": {
6215 "op": Op.MATMUL,
6216 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07006217 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08006218 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
6219 "qgen": TosaQuantGen.qgMatmul,
6220 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006221 "error_if_validators": (
6222 TosaErrorValidator.evInputZeroPointNotZero,
6223 TosaErrorValidator.evWrongRank,
6224 TosaErrorValidator.evWrongInputType,
6225 TosaErrorValidator.evWrongOutputType,
6226 TosaErrorValidator.evWrongInputList,
6227 TosaErrorValidator.evWrongOutputList,
6228 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006229 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006230 "max_pool2d": {
6231 "op": Op.MAX_POOL2D,
6232 "operands": (1, 0),
6233 "rank": (4, 4),
6234 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
6235 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00006236 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006237 "error_if_validators": (
6238 TosaErrorValidator.evKernelSmallerOne,
6239 TosaErrorValidator.evStrideSmallerOne,
6240 TosaErrorValidator.evPadSmallerZero,
6241 TosaErrorValidator.evWrongRank,
6242 TosaErrorValidator.evWrongInputType,
6243 TosaErrorValidator.evWrongOutputType,
6244 TosaErrorValidator.evWrongInputList,
6245 TosaErrorValidator.evWrongOutputList,
6246 TosaErrorValidator.evPadLargerEqualKernel,
6247 TosaErrorValidator.evPoolingOutputShapeMismatch,
6248 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006249 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006250 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08006251 "transpose_conv2d_TEMPLATE": {
6252 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07006253 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006254 "rank": (4, 4),
6255 "build_fcn": (
6256 build_transpose_conv2d,
6257 TosaTensorGen.tgTransposeConv2D,
6258 TosaArgGen.agTransposeConv2D,
6259 ),
6260 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07006261 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00006262 "invalid_test_validators": (
6263 TosaInvalidValidator.ivHeightWidthInvalid,
6264 TosaInvalidValidator.ivNonPositiveOutputShape,
6265 ),
6266 "error_if_validators": (
6267 TosaErrorValidator.evWrongInputType,
6268 TosaErrorValidator.evWrongOutputType,
6269 TosaErrorValidator.evWrongInputList,
6270 TosaErrorValidator.evWrongOutputList,
6271 TosaErrorValidator.evInputZeroPointNotZero,
6272 TosaErrorValidator.evWeightZeroPointNotZero,
6273 TosaErrorValidator.evPadSmallerZero,
6274 TosaErrorValidator.evStrideSmallerOne,
6275 TosaErrorValidator.evDilationSmallerOne,
6276 TosaErrorValidator.evWrongRank,
6277 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006278 "template": True,
6279 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006280 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08006281 "clamp": {
6282 "op": Op.CLAMP,
6283 "operands": (1, 0),
6284 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
6285 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006286 "error_if_validators": (
6287 TosaErrorValidator.evMaxSmallerMin,
6288 TosaErrorValidator.evWrongInputType,
6289 TosaErrorValidator.evWrongOutputType,
6290 TosaErrorValidator.evWrongInputList,
6291 TosaErrorValidator.evWrongOutputList,
6292 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006293 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08006294 "sigmoid": {
6295 "op": Op.SIGMOID,
6296 "operands": (1, 0),
6297 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
6298 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006299 "error_if_validators": (
6300 TosaErrorValidator.evWrongInputType,
6301 TosaErrorValidator.evWrongOutputType,
6302 TosaErrorValidator.evWrongInputList,
6303 TosaErrorValidator.evWrongOutputList,
6304 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006305 },
6306 "tanh": {
6307 "op": Op.TANH,
6308 "operands": (1, 0),
6309 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
6310 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006311 "error_if_validators": (
6312 TosaErrorValidator.evWrongInputType,
6313 TosaErrorValidator.evWrongOutputType,
6314 TosaErrorValidator.evWrongInputList,
6315 TosaErrorValidator.evWrongOutputList,
6316 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006317 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006318 # Elementwise Binary Operators
6319 "add": {
6320 "op": Op.ADD,
6321 "operands": (2, 0),
6322 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6323 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006324 "error_if_validators": (
6325 TosaErrorValidator.evRankMismatch,
6326 TosaErrorValidator.evWrongInputType,
6327 TosaErrorValidator.evWrongOutputType,
6328 TosaErrorValidator.evWrongInputList,
6329 TosaErrorValidator.evWrongOutputList,
6330 TosaErrorValidator.evDimensionMismatch,
6331 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006332 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006333 "arithmetic_right_shift": {
6334 "op": Op.ARITHMETIC_RIGHT_SHIFT,
6335 "operands": (2, 0),
6336 "build_fcn": (
6337 build_arithmetic_right_shift,
6338 TosaTensorGen.tgBroadcastFuzz,
6339 TosaArgGen.agArithmeticRightShift,
6340 ),
6341 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006342 "error_if_validators": (
6343 TosaErrorValidator.evRankMismatch,
6344 TosaErrorValidator.evWrongInputType,
6345 TosaErrorValidator.evWrongOutputType,
6346 TosaErrorValidator.evWrongInputList,
6347 TosaErrorValidator.evWrongOutputList,
6348 TosaErrorValidator.evDimensionMismatch,
6349 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006350 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006351 "bitwise_and": {
6352 "op": Op.BITWISE_AND,
6353 "operands": (2, 0),
6354 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6355 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006356 "error_if_validators": (
6357 TosaErrorValidator.evRankMismatch,
6358 TosaErrorValidator.evWrongInputType,
6359 TosaErrorValidator.evWrongOutputType,
6360 TosaErrorValidator.evWrongInputList,
6361 TosaErrorValidator.evWrongOutputList,
6362 TosaErrorValidator.evDimensionMismatch,
6363 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006364 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006365 "bitwise_or": {
6366 "op": Op.BITWISE_OR,
6367 "operands": (2, 0),
6368 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6369 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006370 "error_if_validators": (
6371 TosaErrorValidator.evRankMismatch,
6372 TosaErrorValidator.evWrongInputType,
6373 TosaErrorValidator.evWrongOutputType,
6374 TosaErrorValidator.evWrongInputList,
6375 TosaErrorValidator.evWrongOutputList,
6376 TosaErrorValidator.evDimensionMismatch,
6377 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006378 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006379 "bitwise_xor": {
6380 "op": Op.BITWISE_XOR,
6381 "operands": (2, 0),
6382 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6383 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006384 "error_if_validators": (
6385 TosaErrorValidator.evRankMismatch,
6386 TosaErrorValidator.evWrongInputType,
6387 TosaErrorValidator.evWrongOutputType,
6388 TosaErrorValidator.evWrongInputList,
6389 TosaErrorValidator.evWrongOutputList,
6390 TosaErrorValidator.evDimensionMismatch,
6391 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006392 },
Matthew Haddon459443c2021-08-23 16:43:13 +01006393 "intdiv": {
6394 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07006395 "operands": (2, 0),
6396 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6397 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006398 "error_if_validators": (
6399 TosaErrorValidator.evRankMismatch,
6400 TosaErrorValidator.evWrongInputType,
6401 TosaErrorValidator.evWrongOutputType,
6402 TosaErrorValidator.evWrongInputList,
6403 TosaErrorValidator.evWrongOutputList,
6404 TosaErrorValidator.evDimensionMismatch,
6405 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07006406 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006407 "logical_and": {
6408 "op": Op.LOGICAL_AND,
6409 "operands": (2, 0),
6410 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6411 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006412 "error_if_validators": (
6413 TosaErrorValidator.evRankMismatch,
6414 TosaErrorValidator.evWrongInputType,
6415 TosaErrorValidator.evWrongOutputType,
6416 TosaErrorValidator.evWrongInputList,
6417 TosaErrorValidator.evWrongOutputList,
6418 TosaErrorValidator.evDimensionMismatch,
6419 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006420 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006421 "logical_left_shift": {
6422 "op": Op.LOGICAL_LEFT_SHIFT,
6423 "operands": (2, 0),
6424 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6425 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006426 "error_if_validators": (
6427 TosaErrorValidator.evRankMismatch,
6428 TosaErrorValidator.evWrongInputType,
6429 TosaErrorValidator.evWrongOutputType,
6430 TosaErrorValidator.evWrongInputList,
6431 TosaErrorValidator.evWrongOutputList,
6432 TosaErrorValidator.evDimensionMismatch,
6433 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006434 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006435 "logical_right_shift": {
6436 "op": Op.LOGICAL_RIGHT_SHIFT,
6437 "operands": (2, 0),
6438 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6439 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006440 "error_if_validators": (
6441 TosaErrorValidator.evRankMismatch,
6442 TosaErrorValidator.evWrongInputType,
6443 TosaErrorValidator.evWrongOutputType,
6444 TosaErrorValidator.evWrongInputList,
6445 TosaErrorValidator.evWrongOutputList,
6446 TosaErrorValidator.evDimensionMismatch,
6447 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006448 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006449 "logical_or": {
6450 "op": Op.LOGICAL_OR,
6451 "operands": (2, 0),
6452 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6453 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006454 "error_if_validators": (
6455 TosaErrorValidator.evRankMismatch,
6456 TosaErrorValidator.evWrongInputType,
6457 TosaErrorValidator.evWrongOutputType,
6458 TosaErrorValidator.evWrongInputList,
6459 TosaErrorValidator.evWrongOutputList,
6460 TosaErrorValidator.evDimensionMismatch,
6461 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006462 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006463 "logical_xor": {
6464 "op": Op.LOGICAL_XOR,
6465 "operands": (2, 0),
6466 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6467 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006468 "error_if_validators": (
6469 TosaErrorValidator.evRankMismatch,
6470 TosaErrorValidator.evWrongInputType,
6471 TosaErrorValidator.evWrongOutputType,
6472 TosaErrorValidator.evWrongInputList,
6473 TosaErrorValidator.evWrongOutputList,
6474 TosaErrorValidator.evDimensionMismatch,
6475 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006476 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006477 "maximum": {
6478 "op": Op.MAXIMUM,
6479 "operands": (2, 0),
6480 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6481 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006482 "error_if_validators": (
6483 TosaErrorValidator.evRankMismatch,
6484 TosaErrorValidator.evWrongInputType,
6485 TosaErrorValidator.evWrongOutputType,
6486 TosaErrorValidator.evWrongInputList,
6487 TosaErrorValidator.evWrongOutputList,
6488 TosaErrorValidator.evDimensionMismatch,
6489 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006490 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006491 "minimum": {
6492 "op": Op.MINIMUM,
6493 "operands": (2, 0),
6494 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6495 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006496 "error_if_validators": (
6497 TosaErrorValidator.evRankMismatch,
6498 TosaErrorValidator.evWrongInputType,
6499 TosaErrorValidator.evWrongOutputType,
6500 TosaErrorValidator.evWrongInputList,
6501 TosaErrorValidator.evWrongOutputList,
6502 TosaErrorValidator.evDimensionMismatch,
6503 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006504 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006505 "mul": {
6506 "op": Op.MUL,
6507 "operands": (2, 0),
6508 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
6509 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006510 "error_if_validators": (
6511 TosaErrorValidator.evWrongInputType,
6512 TosaErrorValidator.evWrongOutputType,
6513 TosaErrorValidator.evWrongInputList,
6514 TosaErrorValidator.evWrongOutputList,
6515 TosaErrorValidator.evRankMismatch,
6516 TosaErrorValidator.evDimensionMismatch,
6517 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006518 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006519 "pow": {
6520 "op": Op.POW,
6521 "operands": (2, 0),
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00006522 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08006523 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006524 "error_if_validators": (
6525 TosaErrorValidator.evRankMismatch,
6526 TosaErrorValidator.evWrongInputType,
6527 TosaErrorValidator.evWrongOutputType,
6528 TosaErrorValidator.evWrongInputList,
6529 TosaErrorValidator.evWrongOutputList,
6530 TosaErrorValidator.evDimensionMismatch,
6531 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006532 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006533 "sub": {
6534 "op": Op.SUB,
6535 "operands": (2, 0),
6536 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6537 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006538 "error_if_validators": (
6539 TosaErrorValidator.evRankMismatch,
6540 TosaErrorValidator.evWrongInputType,
6541 TosaErrorValidator.evWrongOutputType,
6542 TosaErrorValidator.evWrongInputList,
6543 TosaErrorValidator.evWrongOutputList,
6544 TosaErrorValidator.evDimensionMismatch,
6545 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006546 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006547 "table": {
6548 "op": Op.TABLE,
6549 # Use the automatic generation functions to create the input array
6550 # but create the table tensor in the build function, as it may be
6551 # a different type from the input
6552 "operands": (1, 0),
Kevin Chengfe392ce2021-10-18 21:51:55 +00006553 "build_fcn": (build_table, TosaTensorGen.tgBasic, TosaArgGen.agTable),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01006554 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006555 "error_if_validators": (
6556 TosaErrorValidator.evWrongInputType,
6557 TosaErrorValidator.evWrongOutputType,
6558 TosaErrorValidator.evWrongInputList,
6559 TosaErrorValidator.evWrongOutputList,
6560 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006561 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006562 # Elementwise Unary operators
6563 "abs": {
6564 "op": Op.ABS,
6565 "operands": (1, 0),
6566 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6567 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006568 "error_if_validators": (
6569 TosaErrorValidator.evWrongInputType,
6570 TosaErrorValidator.evWrongOutputType,
6571 TosaErrorValidator.evWrongInputList,
6572 TosaErrorValidator.evWrongOutputList,
6573 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006574 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006575 "bitwise_not": {
6576 "op": Op.BITWISE_NOT,
6577 "operands": (1, 0),
6578 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6579 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006580 "error_if_validators": (
6581 TosaErrorValidator.evWrongInputType,
6582 TosaErrorValidator.evWrongOutputType,
6583 TosaErrorValidator.evWrongInputList,
6584 TosaErrorValidator.evWrongOutputList,
6585 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006586 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006587 "ceil": {
6588 "op": Op.CEIL,
6589 "operands": (1, 0),
6590 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6591 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006592 "error_if_validators": (
6593 TosaErrorValidator.evWrongInputType,
6594 TosaErrorValidator.evWrongOutputType,
6595 TosaErrorValidator.evWrongInputList,
6596 TosaErrorValidator.evWrongOutputList,
6597 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006598 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006599 "clz": {
6600 "op": Op.CLZ,
6601 "operands": (1, 0),
6602 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6603 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006604 "error_if_validators": (
6605 TosaErrorValidator.evWrongInputType,
6606 TosaErrorValidator.evWrongOutputType,
6607 TosaErrorValidator.evWrongInputList,
6608 TosaErrorValidator.evWrongOutputList,
6609 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006610 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006611 "exp": {
6612 "op": Op.EXP,
6613 "operands": (1, 0),
6614 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6615 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006616 "error_if_validators": (
6617 TosaErrorValidator.evWrongInputType,
6618 TosaErrorValidator.evWrongOutputType,
6619 TosaErrorValidator.evWrongInputList,
6620 TosaErrorValidator.evWrongOutputList,
6621 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006622 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006623 "floor": {
6624 "op": Op.FLOOR,
6625 "operands": (1, 0),
6626 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6627 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006628 "error_if_validators": (
6629 TosaErrorValidator.evWrongInputType,
6630 TosaErrorValidator.evWrongOutputType,
6631 TosaErrorValidator.evWrongInputList,
6632 TosaErrorValidator.evWrongOutputList,
6633 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006634 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006635 "log": {
6636 "op": Op.LOG,
6637 "operands": (1, 0),
6638 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6639 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006640 "error_if_validators": (
6641 TosaErrorValidator.evWrongInputType,
6642 TosaErrorValidator.evWrongOutputType,
6643 TosaErrorValidator.evWrongInputList,
6644 TosaErrorValidator.evWrongOutputList,
6645 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006646 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006647 "logical_not": {
6648 "op": Op.LOGICAL_NOT,
6649 "operands": (1, 0),
6650 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6651 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006652 "error_if_validators": (
6653 TosaErrorValidator.evWrongInputType,
6654 TosaErrorValidator.evWrongOutputType,
6655 TosaErrorValidator.evWrongInputList,
6656 TosaErrorValidator.evWrongOutputList,
6657 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006658 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006659 "negate": {
6660 "op": Op.NEGATE,
6661 "operands": (1, 0),
6662 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6663 "qgen": TosaQuantGen.qgUnary,
6664 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006665 "error_if_validators": (
6666 TosaErrorValidator.evInputZeroPointNotZero,
6667 TosaErrorValidator.evOutputZeroPointNotZero,
6668 TosaErrorValidator.evWrongInputType,
6669 TosaErrorValidator.evWrongOutputType,
6670 TosaErrorValidator.evWrongInputList,
6671 TosaErrorValidator.evWrongOutputList,
6672 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006673 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006674 "reciprocal": {
6675 "op": Op.RECIPROCAL,
6676 "operands": (1, 0),
6677 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6678 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006679 "error_if_validators": (
6680 TosaErrorValidator.evWrongInputType,
6681 TosaErrorValidator.evWrongOutputType,
6682 TosaErrorValidator.evWrongInputList,
6683 TosaErrorValidator.evWrongOutputList,
6684 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006685 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006686 "rsqrt": {
6687 "op": Op.RSQRT,
6688 "operands": (1, 0),
6689 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6690 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006691 "error_if_validators": (
6692 TosaErrorValidator.evWrongInputType,
6693 TosaErrorValidator.evWrongOutputType,
6694 TosaErrorValidator.evWrongInputList,
6695 TosaErrorValidator.evWrongOutputList,
6696 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006697 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006698 # Elementwise Ternary operators
6699 "select": {
6700 "op": Op.SELECT,
6701 "operands": (3, 0),
6702 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
6703 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006704 "error_if_validators": (
6705 TosaErrorValidator.evRankMismatch,
6706 TosaErrorValidator.evWrongInputType,
6707 TosaErrorValidator.evWrongOutputType,
6708 TosaErrorValidator.evWrongInputList,
6709 TosaErrorValidator.evWrongOutputList,
6710 TosaErrorValidator.evDimensionMismatch,
6711 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006712 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006713 # Comparison operators
6714 "equal": {
6715 "op": Op.EQUAL,
6716 "operands": (2, 0),
6717 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
6718 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006719 "error_if_validators": (
6720 TosaErrorValidator.evRankMismatch,
6721 TosaErrorValidator.evWrongInputType,
6722 TosaErrorValidator.evWrongOutputType,
6723 TosaErrorValidator.evWrongInputList,
6724 TosaErrorValidator.evWrongOutputList,
6725 TosaErrorValidator.evDimensionMismatch,
6726 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006727 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006728 "greater_equal": {
6729 "op": Op.GREATER_EQUAL,
6730 "operands": (2, 0),
6731 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
6732 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006733 "error_if_validators": (
6734 TosaErrorValidator.evRankMismatch,
6735 TosaErrorValidator.evWrongInputType,
6736 TosaErrorValidator.evWrongOutputType,
6737 TosaErrorValidator.evWrongInputList,
6738 TosaErrorValidator.evWrongOutputList,
6739 TosaErrorValidator.evDimensionMismatch,
6740 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006741 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006742 "greater": {
6743 "op": Op.GREATER,
6744 "operands": (2, 0),
6745 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
6746 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006747 "error_if_validators": (
6748 TosaErrorValidator.evRankMismatch,
6749 TosaErrorValidator.evWrongInputType,
6750 TosaErrorValidator.evWrongOutputType,
6751 TosaErrorValidator.evWrongInputList,
6752 TosaErrorValidator.evWrongOutputList,
6753 TosaErrorValidator.evDimensionMismatch,
6754 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006755 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006756 # Reduction operators
6757 "reduce_all": {
6758 "op": Op.REDUCE_ALL,
6759 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006760 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006761 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6762 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006763 "error_if_validators": (
6764 TosaErrorValidator.evAxisLargerRank,
6765 TosaErrorValidator.evAxisSmallerZero,
6766 TosaErrorValidator.evShapeOfAxisNotOne,
6767 TosaErrorValidator.evWrongInputType,
6768 TosaErrorValidator.evWrongOutputType,
6769 TosaErrorValidator.evWrongRank,
6770 TosaErrorValidator.evWrongInputList,
6771 TosaErrorValidator.evWrongOutputList,
6772 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006773 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006774 "reduce_any": {
6775 "op": Op.REDUCE_ANY,
6776 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006777 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006778 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6779 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006780 "error_if_validators": (
6781 TosaErrorValidator.evAxisLargerRank,
6782 TosaErrorValidator.evAxisSmallerZero,
6783 TosaErrorValidator.evShapeOfAxisNotOne,
6784 TosaErrorValidator.evWrongInputType,
6785 TosaErrorValidator.evWrongOutputType,
6786 TosaErrorValidator.evWrongRank,
6787 TosaErrorValidator.evWrongInputList,
6788 TosaErrorValidator.evWrongOutputList,
6789 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006790 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006791 "reduce_max": {
6792 "op": Op.REDUCE_MAX,
6793 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006794 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006795 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6796 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006797 "error_if_validators": (
6798 TosaErrorValidator.evAxisLargerRank,
6799 TosaErrorValidator.evAxisSmallerZero,
6800 TosaErrorValidator.evShapeOfAxisNotOne,
6801 TosaErrorValidator.evWrongInputType,
6802 TosaErrorValidator.evWrongOutputType,
6803 TosaErrorValidator.evWrongRank,
6804 TosaErrorValidator.evWrongInputList,
6805 TosaErrorValidator.evWrongOutputList,
6806 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006807 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006808 "reduce_min": {
6809 "op": Op.REDUCE_MAX,
6810 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006811 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006812 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6813 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006814 "error_if_validators": (
6815 TosaErrorValidator.evAxisLargerRank,
6816 TosaErrorValidator.evAxisSmallerZero,
6817 TosaErrorValidator.evShapeOfAxisNotOne,
6818 TosaErrorValidator.evWrongInputType,
6819 TosaErrorValidator.evWrongOutputType,
6820 TosaErrorValidator.evWrongRank,
6821 TosaErrorValidator.evWrongInputList,
6822 TosaErrorValidator.evWrongOutputList,
6823 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006824 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006825 "reduce_product": {
6826 "op": Op.REDUCE_PRODUCT,
6827 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006828 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006829 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6830 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006831 "error_if_validators": (
6832 TosaErrorValidator.evAxisLargerRank,
6833 TosaErrorValidator.evAxisSmallerZero,
6834 TosaErrorValidator.evShapeOfAxisNotOne,
6835 TosaErrorValidator.evWrongInputType,
6836 TosaErrorValidator.evWrongOutputType,
6837 TosaErrorValidator.evWrongRank,
6838 TosaErrorValidator.evWrongInputList,
6839 TosaErrorValidator.evWrongOutputList,
6840 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006841 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006842 "reduce_sum": {
6843 "op": Op.REDUCE_SUM,
6844 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006845 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006846 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6847 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006848 "error_if_validators": (
6849 TosaErrorValidator.evAxisLargerRank,
6850 TosaErrorValidator.evAxisSmallerZero,
6851 TosaErrorValidator.evShapeOfAxisNotOne,
6852 TosaErrorValidator.evWrongInputType,
6853 TosaErrorValidator.evWrongOutputType,
6854 TosaErrorValidator.evWrongRank,
6855 TosaErrorValidator.evWrongInputList,
6856 TosaErrorValidator.evWrongOutputList,
6857 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006858 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006859 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08006860 "concat": {
6861 "op": Op.CONCAT,
6862 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01006863 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006864 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006865 "error_if_validators": (
6866 TosaErrorValidator.evAxisLargerRank,
6867 TosaErrorValidator.evAxisSmallerZero,
6868 TosaErrorValidator.evConcatInputRankMismatch,
6869 TosaErrorValidator.evConcatShapeSumMismatch,
6870 TosaErrorValidator.evConcatInputDimMismatch,
6871 TosaErrorValidator.evWrongInputType,
6872 TosaErrorValidator.evWrongOutputType,
6873 TosaErrorValidator.evWrongOutputList,
6874 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006875 },
6876 "pad": {
6877 "op": Op.PAD,
6878 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01006879 "rank": (1, 5),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006880 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
6881 "qgen": TosaQuantGen.qgPad,
6882 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006883 "error_if_validators": (
6884 TosaErrorValidator.evWrongInputType,
6885 TosaErrorValidator.evPadSmallerZero,
6886 TosaErrorValidator.evWrongOutputType,
6887 TosaErrorValidator.evWrongInputList,
6888 TosaErrorValidator.evWrongOutputList,
6889 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006890 },
6891 "reshape": {
6892 "op": Op.RESHAPE,
6893 "operands": (1, 0),
6894 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
6895 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006896 "error_if_validators": (
6897 TosaErrorValidator.evTensorSizeInputOutputMismatch,
6898 TosaErrorValidator.evWrongInputType,
6899 TosaErrorValidator.evWrongOutputType,
6900 TosaErrorValidator.evWrongInputList,
6901 TosaErrorValidator.evWrongOutputList,
6902 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006903 },
6904 "reverse": {
6905 "op": Op.REVERSE,
6906 "operands": (1, 0),
6907 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6908 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006909 "error_if_validators": (
6910 TosaErrorValidator.evAxisSmallerZero,
6911 TosaErrorValidator.evAxisLargerRank,
6912 TosaErrorValidator.evWrongInputType,
6913 TosaErrorValidator.evWrongOutputType,
6914 TosaErrorValidator.evWrongInputList,
6915 TosaErrorValidator.evWrongOutputList,
6916 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006917 },
6918 "slice": {
6919 "op": Op.SLICE,
6920 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01006921 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006922 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
6923 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006924 "error_if_validators": (
6925 TosaErrorValidator.evStartSmallerZero,
6926 TosaErrorValidator.evSizeSmallerEqualZero,
6927 TosaErrorValidator.evStartSizeOutsideBounds,
6928 TosaErrorValidator.evSizeOutputShapeMismatch,
6929 TosaErrorValidator.evInputSizeStartLengthMismatch,
6930 TosaErrorValidator.evWrongRank,
6931 TosaErrorValidator.evWrongInputType,
6932 TosaErrorValidator.evWrongOutputType,
6933 TosaErrorValidator.evWrongInputList,
6934 TosaErrorValidator.evWrongOutputList,
6935 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006936 },
6937 "tile": {
6938 "op": Op.TILE,
6939 "operands": (1, 0),
6940 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
6941 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006942 "error_if_validators": (
6943 TosaErrorValidator.evWrongInputType,
6944 TosaErrorValidator.evWrongOutputType,
6945 TosaErrorValidator.evWrongInputList,
6946 TosaErrorValidator.evWrongOutputList,
6947 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006948 },
6949 "transpose": {
6950 "op": Op.TRANSPOSE,
6951 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01006952 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006953 "build_fcn": (
6954 build_transpose,
6955 TosaTensorGen.tgBasic,
6956 TosaArgGen.agTranspose,
6957 ),
6958 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006959 "error_if_validators": (
6960 TosaErrorValidator.evIndexOutsideBounds,
6961 TosaErrorValidator.evIndexUsedTwice,
6962 TosaErrorValidator.evWrongInputType,
6963 TosaErrorValidator.evWrongOutputType,
6964 TosaErrorValidator.evWrongInputList,
6965 TosaErrorValidator.evWrongOutputList,
6966 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006967 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006968 # Data nodes
6969 "const": {
6970 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07006971 "operands": (0, 1),
6972 "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08006973 "types": TYPE_FIB,
6974 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006975 "identity": {
6976 "op": Op.IDENTITY,
6977 "operands": (1, 0),
6978 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6979 "types": TYPE_FIB,
6980 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006981 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08006982 "gather": {
6983 "op": Op.GATHER,
6984 # Only specify 'values' tensor here. 'indices' is generated in op building stage
6985 "operands": (1, 0),
6986 "rank": (3, 3),
6987 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
6988 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006989 "error_if_validators": (
6990 TosaErrorValidator.evWrongInputType,
6991 TosaErrorValidator.evWrongOutputType,
6992 TosaErrorValidator.evWrongInputList,
6993 TosaErrorValidator.evWrongOutputList,
6994 TosaErrorValidator.evWrongRank,
6995 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006996 },
6997 "scatter": {
6998 "op": Op.SCATTER,
6999 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007000 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08007001 "operands": (2, 0),
7002 "rank": (3, 3),
7003 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
7004 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007005 "error_if_validators": (
7006 TosaErrorValidator.evWrongInputType,
7007 TosaErrorValidator.evWrongOutputType,
7008 TosaErrorValidator.evWrongInputList,
7009 TosaErrorValidator.evWrongOutputList,
7010 TosaErrorValidator.evWrongRank,
7011 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007012 },
Eric Kunzee5e26762020-10-13 16:11:07 -07007013 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08007014 "resize": {
7015 "op": Op.RESIZE,
7016 "operands": (1, 0),
7017 "rank": (4, 4),
7018 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
7019 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007020 "invalid_test_validators": (
7021 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
7022 TosaInvalidValidator.ivBadStride,
7023 ),
7024 "error_if_validators": (
7025 TosaErrorValidator.evMaxDimExceeded,
7026 TosaErrorValidator.evStrideSmallerEqualZero,
7027 TosaErrorValidator.evStrideLargerDimension,
7028 TosaErrorValidator.evStrideLargerEqualMax,
7029 TosaErrorValidator.evOffsetSmallerEqualMin,
7030 TosaErrorValidator.evOffsetLargerEqualMax,
7031 TosaErrorValidator.evShiftNotZero,
7032 TosaErrorValidator.evShiftSmallerOne,
7033 TosaErrorValidator.evShiftLargerEleven,
7034 TosaErrorValidator.evWrongInputType,
7035 TosaErrorValidator.evWrongOutputType,
7036 TosaErrorValidator.evWrongRank,
7037 TosaErrorValidator.evWrongInputList,
7038 TosaErrorValidator.evWrongOutputList,
7039 TosaErrorValidator.evBatchMismatch,
7040 TosaErrorValidator.evChannelMismatch,
7041 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007042 },
Eric Kunzee5e26762020-10-13 16:11:07 -07007043 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08007044 "cast": {
7045 "op": Op.CAST,
7046 "operands": (1, 0),
7047 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
7048 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007049 "error_if_validators": (
7050 TosaErrorValidator.evWrongInputType,
7051 TosaErrorValidator.evWrongOutputType,
7052 TosaErrorValidator.evWrongInputList,
7053 TosaErrorValidator.evWrongOutputList,
7054 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007055 },
7056 "rescale": {
7057 "op": Op.RESCALE,
7058 "operands": (1, 0),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007059 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007060 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01007061 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007062 "error_if_validators": (
7063 TosaErrorValidator.evInputZeroPointNotZero,
7064 TosaErrorValidator.evOutputZeroPointNotZero,
7065 TosaErrorValidator.evScaleTrue,
7066 TosaErrorValidator.evScaleNotTrue,
7067 TosaErrorValidator.evWrongInputType,
7068 TosaErrorValidator.evWrongOutputType,
7069 TosaErrorValidator.evWrongRank,
7070 TosaErrorValidator.evWrongInputList,
7071 TosaErrorValidator.evWrongOutputList,
7072 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007073 },
Eric Kunzee5e26762020-10-13 16:11:07 -07007074 # Custom
7075 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08007076 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07007077    # Two variants of cond_if, one that generates one of two constant tensors (no
7078 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
7079 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08007080 "cond_if_const": {
7081 "op": Op.COND_IF,
7082 "operands": (0, 2),
7083 "build_fcn": (
7084 build_cond_if_const,
7085 TosaTensorGen.tgBasic,
7086 TosaArgGen.agCondIf,
7087 ),
7088 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007089 "error_if_validators": (
7090 TosaErrorValidator.evOutputListThenGraphMismatch,
7091 TosaErrorValidator.evOutputListElseGraphMismatch,
7092 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007093 },
7094 "cond_if_binary": {
7095 "op": Op.COND_IF,
7096 "operands": (2, 0),
7097 "build_fcn": (
7098 build_cond_if_binary,
7099 TosaTensorGen.tgBasic,
7100 TosaArgGen.agCondIf,
7101 ),
Les Bell6040b4d2021-10-11 12:50:31 +01007102 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007103 "error_if_validators": (
7104 TosaErrorValidator.evInputListThenGraphMismatch,
7105 TosaErrorValidator.evInputListElseGraphMismatch,
7106 TosaErrorValidator.evOutputListThenGraphMismatch,
7107 TosaErrorValidator.evOutputListElseGraphMismatch,
7108 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007109 },
Eric Kunzee5e26762020-10-13 16:11:07 -07007110 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08007111 "while_loop": {
7112 "op": Op.WHILE_LOOP,
7113 "operands": (0, 1),
7114 "build_fcn": (
7115 build_while_loop,
7116 TosaTensorGen.tgBasic,
7117 TosaArgGen.agWhileLoop,
7118 ),
7119 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007120 "error_if_validators": (
7121 TosaErrorValidator.evInputListOutputListMismatch,
7122 TosaErrorValidator.evInputListCondGraphMismatch,
7123 TosaErrorValidator.evInputListBodyGraphInputMismatch,
7124 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
7125 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
7126 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007127 },
Eric Kunzee5e26762020-10-13 16:11:07 -07007128 }
7129
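    # Summary of the per-operator entry layout used in the dictionary above (a
    # reader's aid, not part of the original source): "op" is the serializer Op
    # enum value, "operands" the operand-count tuple passed to the builder,
    # "rank" an optional (min, max) rank constraint, "build_fcn" a
    # (build function, tensor generator, argument generator) tuple, "qgen" an
    # optional quantization generator, "types" the allowed input DTypes,
    # "invalid_test_validators"/"error_if_validators" the negative-test checks,
    # and "template" marks entries expanded later by createDynamicOpLists.
    # Hypothetical lookup sketch (OP_LIST stands in for the dictionary above,
    # whose actual name is defined earlier in this file):
    #   entry = OP_LIST["add"]
    #   build_fcn, tensor_gen, arg_gen = entry["build_fcn"]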
Kevin Cheng550ccc52021-03-03 11:21:43 -08007130
Eric Kunzee5e26762020-10-13 16:11:07 -07007131class OutputShaper:
7132 # Methods in this class compute the expected output shape and datatype
7133 # for common classes of operations
7134 def __init__(self):
7135 pass
7136
7137 # These methods return arguments that can be used for
7138 # creating a new output tensor
7139 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01007140 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
7141 if error_name != ErrorIf.RankMismatch:
7142 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08007143 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07007144
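        # Broadcast rule applied below: a dimension of size 1 in 'a' takes the
        # corresponding dimension of 'b'. Illustrative example (not from the
        # original source): a.shape [1, 3, 4], b.shape [2, 3, 4] -> [2, 3, 4].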
7145 shape = []
7146 for i in range(len(a.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007147 if a.shape[i] == 1 and error_name is None:
Eric Kunzee5e26762020-10-13 16:11:07 -07007148 shape.append(b.shape[i])
7149 else:
7150 shape.append(a.shape[i])
7151
Matthew Haddoneacff9a2021-09-24 14:42:13 +01007152 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007153 all_dtypes = [
7154 DType.INT8,
7155 DType.INT16,
7156 DType.INT32,
7157 DType.INT48,
7158 DType.FLOAT,
7159 ]
Matthew Haddoneacff9a2021-09-24 14:42:13 +01007160 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
7161 outputDType = rng.choice(wrong_dtypes)
7162 else:
7163 outputDType = a.dtype
7164
7165 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07007166
7167 @staticmethod
7168 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08007169 assert len(a.shape) == len(b.shape)
7170 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07007171
7172 shape = []
7173 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08007174 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07007175 shape.append(a.shape[i])
7176
Kevin Cheng550ccc52021-03-03 11:21:43 -08007177 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07007178
7179 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01007180 def unaryOp(ser, rng, a, error_name=None):
7181 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007182 all_dtypes = [
7183 DType.INT8,
7184 DType.INT16,
7185 DType.INT32,
7186 DType.INT48,
7187 DType.FLOAT,
7188 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01007189 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
7190 outputDType = rng.choice(wrong_dtypes)
7191 else:
7192 outputDType = a.dtype
7193
7194 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07007195
7196 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01007197 def selectOp(ser, rng, cond, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00007198 if error_name != ErrorIf.RankMismatch:
7199 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08007200 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07007201
7202 shape = []
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00007203 for i in range(len(cond.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007204 if cond.shape[i] == 1 and error_name is None:
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00007205 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
7206 else:
7207 shape.append(cond.shape[i])
Eric Kunzee5e26762020-10-13 16:11:07 -07007208
Matthew Haddonbb5676f2021-10-13 11:30:30 +01007209 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007210 all_dtypes = [
7211 DType.INT8,
7212 DType.INT16,
7213 DType.INT32,
7214 DType.INT48,
7215 DType.FLOAT,
7216 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01007217 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
7218 outputDType = rng.choice(wrong_dtypes)
7219 else:
7220 outputDType = a.dtype
7221
7222 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07007223
7224 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007225 def binaryComparisonOp(ser, rng, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00007226 if error_name != ErrorIf.RankMismatch:
7227 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08007228 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07007229
7230 # Do broadcast
7231 shape = []
7232 for i in range(len(a.shape)):
Eric Kunzea1d49852022-01-04 10:07:29 -08007233 if a.shape[i] == 1 and len(b.shape) > i:
Eric Kunzee5e26762020-10-13 16:11:07 -07007234 shape.append(b.shape[i])
7235 else:
7236 shape.append(a.shape[i])
7237
Matthew Haddonbb5676f2021-10-13 11:30:30 +01007238 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007239 wrong_dtypes = [
7240 DType.INT8,
7241 DType.INT16,
7242 DType.INT32,
7243 DType.INT48,
7244 DType.FLOAT,
7245 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01007246 outputDType = rng.choice(wrong_dtypes)
7247 else:
7248 outputDType = DType.BOOL
7249
7250 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07007251
7252 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01007253 def reduceOp(ser, rng, a, axis, error_name=None):
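        # In the normal (non error_if) case the reduced axis collapses to size 1.
        # Illustrative example: a.shape [2, 3, 4], axis 1 -> output shape [2, 1, 4].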
Eric Kunzee5e26762020-10-13 16:11:07 -07007254 shape = a.shape.copy()
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007255 if error_name not in [
7256 ErrorIf.AxisSmallerZero,
7257 ErrorIf.AxisLargerRank,
7258 ErrorIf.ShapeOfAxisNotOne,
7259 ]:
Matthew Haddond6ce7252021-09-29 15:35:44 +01007260 shape[axis] = 1
7261 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
7262 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07007263
Matthew Haddond6ce7252021-09-29 15:35:44 +01007264 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007265 all_dtypes = [
7266 DType.INT8,
7267 DType.INT16,
7268 DType.INT32,
7269 DType.INT48,
7270 DType.FLOAT,
7271 ]
Matthew Haddond6ce7252021-09-29 15:35:44 +01007272 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
7273 outputDType = rng.choice(wrong_dtypes)
7274 else:
7275 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07007276
Matthew Haddond6ce7252021-09-29 15:35:44 +01007277 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07007278
7279 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007280 def argmaxOp(ser, rng, a, axis, error_name=None):
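        # In the normal case the reduced axis is removed and the output is INT32.
        # Illustrative example: a.shape [2, 3, 4], axis 1 -> output shape [2, 4].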
Eric Kunzee5e26762020-10-13 16:11:07 -07007281 shape = a.shape.copy()
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007282
7283 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
7284 del shape[axis]
7285
7286 if error_name == ErrorIf.ArgmaxOutputRankMismatch:
7287 remove = rng.choice([True, False])
7288 if remove and len(shape) > 1:
7289 del shape[0]
7290 else:
7291 shape.append(1)
7292 elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
7293 for i in range(len(shape)):
7294 shape[i] = shape[i] + rng.integers(1, 10)
7295
7296 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007297 all_dtypes = [
7298 DType.INT8,
7299 DType.INT16,
7300 DType.INT32,
7301 DType.INT48,
7302 DType.FLOAT,
7303 ]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007304 wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
7305 outputDType = rng.choice(wrong_dtypes)
7306 else:
7307 outputDType = DType.INT32
7308
7309 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07007310
7311 @staticmethod
Les Bell0e027d42021-11-09 14:42:14 +00007312 def conv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07007313
7314 # IFM: NHWC
7315 # Filter: OHWI
7316 # OFM: NHWC
7317
7318 if len(padding) == 2:
7319 # Expand padding to 4 parameters in the case of transpose_conv2d
7320 # From H,W to T,B,L,R
7321 padding = [padding[0], padding[0], padding[1], padding[1]]
7322
Kevin Cheng550ccc52021-03-03 11:21:43 -08007323 h = (
7324 ifm.shape[1]
7325 - filter.shape[1]
7326 - (filter.shape[1] - 1) * (dilations[0] - 1)
7327 + padding[0]
7328 + padding[1]
7329 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07007330
Kevin Cheng550ccc52021-03-03 11:21:43 -08007331 w = (
7332 ifm.shape[2]
7333 - filter.shape[2]
7334 - (filter.shape[2] - 1) * (dilations[1] - 1)
7335 + padding[2]
7336 + padding[3]
7337 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07007338
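        # Illustrative example of the formula above (not from the original source):
        # ifm NHWC [1, 16, 16, 4], filter OHWI [8, 3, 3, 4], strides [1, 1],
        # dilations [1, 1], padding [0, 0, 0, 0]:
        #   h = w = (16 - 3 - 0 + 0 + 0) // 1 + 1 = 14  ->  ofm [1, 14, 14, 8]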
Les Bell0e027d42021-11-09 14:42:14 +00007339 # Avoid illegal dimensions, which can be generated in error_if tests
7340 h = max(h, 1)
7341 w = max(w, 1)
7342
Eric Kunzee5e26762020-10-13 16:11:07 -07007343 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
7344
Kevin Cheng3a478572021-01-22 17:21:02 -08007345 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07007346 out_dtype = DType.INT32
7347 elif ifm.dtype == DType.INT16:
7348 out_dtype = DType.INT48
7349 elif ifm.dtype == DType.FLOAT:
7350 out_dtype = DType.FLOAT
Les Bell0e027d42021-11-09 14:42:14 +00007351 elif error_name == ErrorIf.WrongInputType:
7352 # Pick some potentially correct output dtype if input type is incorrect
7353 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07007354 else:
Les Bell0e027d42021-11-09 14:42:14 +00007355 raise Exception(f"Unsupported input dtype: {ifm.dtype}")
7356
7357 if error_name == ErrorIf.WrongOutputType:
7358 wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
7359 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07007360
Kevin Cheng550ccc52021-03-03 11:21:43 -08007361 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07007362
7363 @staticmethod
Les Bell0e027d42021-11-09 14:42:14 +00007364 def conv3dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
Kevin Cheng1533b852021-09-01 12:51:58 -07007365
7366 # IFM: NDHWC
7367 # Filter: ODHWI
7368 # OFM: NDHWC
7369
7370 d = (
7371 ifm.shape[1]
7372 - filter.shape[1]
7373 - (filter.shape[1] - 1) * (dilations[0] - 1)
7374 + padding[0]
7375 + padding[1]
7376 ) // strides[0] + 1
7377
7378 h = (
7379 ifm.shape[2]
7380 - filter.shape[2]
7381 - (filter.shape[2] - 1) * (dilations[1] - 1)
7382 + padding[2]
7383 + padding[3]
7384 ) // strides[1] + 1
7385
7386 w = (
7387 ifm.shape[3]
7388 - filter.shape[3]
7389 - (filter.shape[3] - 1) * (dilations[2] - 1)
7390 + padding[4]
7391 + padding[5]
7392 ) // strides[2] + 1
7393
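        # Same formula as conv2dOp, extended to the depth dimension. Illustrative
        # example: ifm NDHWC [1, 8, 16, 16, 4], filter ODHWI [8, 3, 3, 3, 4], unit
        # strides/dilations, zero padding -> d = 6, h = w = 14, ofm [1, 6, 14, 14, 8].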
Les Bell0e027d42021-11-09 14:42:14 +00007394 # Avoid illegal dimensions, which can be generated in error_if tests
7395 d = max(d, 1)
7396 h = max(h, 1)
7397 w = max(w, 1)
7398
Kevin Cheng1533b852021-09-01 12:51:58 -07007399 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
7400
7401 if ifm.dtype == DType.INT8:
7402 out_dtype = DType.INT32
7403 elif ifm.dtype == DType.INT16:
7404 out_dtype = DType.INT48
7405 elif ifm.dtype == DType.FLOAT:
7406 out_dtype = DType.FLOAT
Les Bell0e027d42021-11-09 14:42:14 +00007407 elif error_name == ErrorIf.WrongInputType:
7408 # Pick some potentially correct output dtype if input type is incorrect
7409 out_dtype = DType.INT32
Kevin Cheng1533b852021-09-01 12:51:58 -07007410 else:
Les Bell0e027d42021-11-09 14:42:14 +00007411 raise Exception(f"Unsupported input dtype: {ifm.dtype}")
7412
7413 if error_name == ErrorIf.WrongOutputType:
7414 wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
7415 out_dtype = rng.choice(wrong_dtypes)
Kevin Cheng1533b852021-09-01 12:51:58 -07007416
7417 return ser.addOutput(ofm_shape, out_dtype)
7418
7419 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007420 def depthwiseConv2dOp(
7421 ser, rng, ifm, filter, strides, padding, dilations, error_name=None
7422 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07007423 # IFM: NHWC
7424 # Filter: HWCM
7425 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08007426 h = (
7427 ifm.shape[1]
7428 - filter.shape[0]
7429 - (filter.shape[0] - 1) * (dilations[0] - 1)
7430 + padding[0]
7431 + padding[1]
7432 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07007433
Kevin Cheng550ccc52021-03-03 11:21:43 -08007434 w = (
7435 ifm.shape[2]
7436 - filter.shape[1]
7437 - (filter.shape[1] - 1) * (dilations[1] - 1)
7438 + padding[2]
7439 + padding[3]
7440 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07007441
Les Bell0e027d42021-11-09 14:42:14 +00007442 # Avoid illegal dimensions, which can be generated in error_if tests
7443 h = max(h, 1)
7444 w = max(w, 1)
7445
Eric Kunzee5e26762020-10-13 16:11:07 -07007446 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
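        # Output channels are C * M from the HWCM filter. Illustrative example:
        # filter [3, 3, 4, 2] -> 4 * 2 = 8 output channels.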
7447
Kevin Cheng3a478572021-01-22 17:21:02 -08007448 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07007449 out_dtype = DType.INT32
7450 elif ifm.dtype == DType.INT16:
7451 out_dtype = DType.INT48
7452 elif ifm.dtype == DType.FLOAT:
7453 out_dtype = DType.FLOAT
Les Bell0e027d42021-11-09 14:42:14 +00007454 elif error_name == ErrorIf.WrongInputType:
7455 # Pick some potentially correct output dtype if input type is incorrect
7456 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07007457 else:
Les Bell0e027d42021-11-09 14:42:14 +00007458 raise Exception(f"Unsupported input dtype: {ifm.dtype}")
7459
7460 if error_name == ErrorIf.WrongOutputType:
7461 wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
7462 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07007463
Kevin Cheng550ccc52021-03-03 11:21:43 -08007464 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07007465
7466 @staticmethod
Matthew Haddonb6b59e32021-10-07 17:19:20 +01007467 def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07007468 # input: NHWC
Matthew Haddonb6b59e32021-10-07 17:19:20 +01007469 if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007470            # If an incorrect stride is used, set the dimensions to 1; the test is invalid anyway.
Matthew Haddonb6b59e32021-10-07 17:19:20 +01007471 h = 1
7472 w = 1
7473 else:
7474 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
7475 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
7476
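        # Illustrative example of the formula above: ifm [1, 32, 32, 8],
        # kernel [2, 2], stride [2, 2], pad [0, 0, 0, 0]:
        #   h = w = (32 + 0 + 0 + 2 - 2) // 2 = 16  ->  ofm [1, 16, 16, 8]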
7477 if error_name == ErrorIf.PoolingOutputShapeMismatch:
7478 choices = [1, 2, 3, 4, 5]
7479 h = h + rng.choice(choices)
7480 w = w + rng.choice(choices)
Eric Kunzee5e26762020-10-13 16:11:07 -07007481
Eric Kunzee5e26762020-10-13 16:11:07 -07007482 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01007483
7484 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007485 all_dtypes = [
7486 DType.INT8,
7487 DType.INT16,
7488 DType.INT32,
7489 DType.INT48,
7490 DType.FLOAT,
7491 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01007492 wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
7493 outputDType = rng.choice(wrong_dtypes)
7494 else:
7495 outputDType = ifm.dtype
7496
7497 return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07007498
7499 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007500 def fullyConnectedOp(ser, rng, input, filter, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07007501 # input: N, IC
7502 # filter: OC, IC
7503 # output: N, OC
7504
7505 output_shape = [input.shape[0], filter.shape[0]]
7506
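        # In the normal case the accumulator dtype widens: INT8 -> INT32,
        # INT16 -> INT48, FLOAT -> FLOAT (mirroring the conv output shapers above).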
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007507 if error_name == ErrorIf.WrongOutputType:
7508 if input.dtype == DType.INT8:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007509 incorrect_types = (
7510 DType.INT4,
7511 DType.INT8,
7512 DType.INT16,
7513 DType.INT48,
7514 DType.FLOAT,
7515 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007516 elif input.dtype == DType.INT16:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007517 incorrect_types = (
7518 DType.INT4,
7519 DType.INT8,
7520 DType.INT16,
7521 DType.INT32,
7522 DType.FLOAT,
7523 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007524 elif input.dtype == DType.FLOAT:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007525 incorrect_types = (
7526 DType.INT4,
7527 DType.INT8,
7528 DType.INT16,
7529 DType.INT32,
7530 DType.INT48,
7531 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007532 out_dtype = rng.choice(a=incorrect_types)
7533 elif input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07007534 out_dtype = DType.INT32
7535 elif input.dtype == DType.INT16:
7536 out_dtype = DType.INT48
7537 elif input.dtype == DType.FLOAT:
7538 out_dtype = DType.FLOAT
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007539 elif error_name == ErrorIf.WrongInputType:
7540 # Pick some potentially correct output dtype if input type is incorrect
7541 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07007542 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08007543            raise Exception(f"Unsupported input dtype: {input.dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -07007544
Kevin Cheng550ccc52021-03-03 11:21:43 -08007545 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07007546
7547 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007548 def matmulOp(ser, rng, a, b, error_name=None):
Kevin Cheng2d60f002021-06-09 14:18:32 -07007549 # a: N, H, C
7550 # b: N, C, W
7551 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07007552
Kevin Cheng2d60f002021-06-09 14:18:32 -07007553 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07007554
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007555 if error_name == ErrorIf.WrongOutputType:
7556 if a.dtype == DType.INT8:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007557 incorrect_types = (
7558 DType.INT4,
7559 DType.INT8,
7560 DType.INT16,
7561 DType.INT48,
7562 DType.FLOAT,
7563 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007564 elif a.dtype == DType.INT16:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007565 incorrect_types = (
7566 DType.INT4,
7567 DType.INT8,
7568 DType.INT16,
7569 DType.INT32,
7570 DType.FLOAT,
7571 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007572 elif a.dtype == DType.FLOAT:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007573 incorrect_types = (
7574 DType.INT4,
7575 DType.INT8,
7576 DType.INT16,
7577 DType.INT32,
7578 DType.INT48,
7579 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007580 out_dtype = rng.choice(a=incorrect_types)
7581 elif a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07007582 out_dtype = DType.INT32
7583 elif a.dtype == DType.INT16:
7584 out_dtype = DType.INT48
7585 elif a.dtype == DType.FLOAT:
7586 out_dtype = DType.FLOAT
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007587 elif error_name == ErrorIf.WrongInputType:
7588 # Pick some potentially correct output dtype if input type is incorrect
7589 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07007590 else:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007591            raise Exception(f"Unsupported input dtype for matmul: {a.dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -07007592
Kevin Cheng550ccc52021-03-03 11:21:43 -08007593 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07007594
7595 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01007596 def concatOp(ser, rng, axis, *a, error_name=None):
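        # The output shape is the first input's shape with the concatenation axis
        # summed over all inputs. Illustrative example: two [2, 3, 4] inputs,
        # axis 1 -> output shape [2, 6, 4].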
Matthew Haddon818ab902021-07-27 09:12:49 +01007597 input1 = a[0]
7598 remaining_inputs = a[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07007599
Matthew Haddonbb5676f2021-10-13 11:30:30 +01007600 # calculate the output shape, if possible, otherwise just use the first input shape
Matthew Haddon818ab902021-07-27 09:12:49 +01007601 output_shape = input1.shape.copy()
Matthew Haddonbb5676f2021-10-13 11:30:30 +01007602 if not (
7603 # unable to concat tensors of different ranks
7604 error_name == ErrorIf.ConcatInputRankMismatch
7605 # unable to concat tensors along an invalid axis
7606 or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01007607 ):
7608 for tensor in remaining_inputs:
7609 output_shape[axis] += tensor.shape[axis]
Eric Kunzee5e26762020-10-13 16:11:07 -07007610
Matthew Haddon01c359d2021-10-15 16:30:48 +01007611 if error_name == ErrorIf.ConcatShapeSumMismatch:
7612 output_shape[axis] += rng.integers(5, 10)
7613
Matthew Haddonbb5676f2021-10-13 11:30:30 +01007614 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007615 all_dtypes = {
7616 DType.INT8,
7617 DType.INT16,
7618 DType.INT32,
7619 DType.INT48,
7620 DType.FLOAT,
7621 }
Matthew Haddonbb5676f2021-10-13 11:30:30 +01007622 wrong_dtypes = list(all_dtypes - set([input1.dtype]))
7623 outputDType = rng.choice(wrong_dtypes)
7624 else:
7625 outputDType = input1.dtype
Matthew Haddon818ab902021-07-27 09:12:49 +01007626
Matthew Haddonbb5676f2021-10-13 11:30:30 +01007627 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07007628
7629 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01007630 def padOp(ser, rng, a, padding, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07007631
7632 output_shape = a.shape.copy()
7633
7634 for i in range(len(output_shape)):
7635 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
7636
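        # Each output dimension grows by the before/after padding for that axis.
        # Illustrative example: a.shape [3, 4], padding [[1, 1], [2, 2]] -> [5, 8].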
Matthew Haddone807aae2021-10-11 18:12:58 +01007637 # Fix negative output shape if error_if test causes it
7638 if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1:
7639 output_shape = [i if i >= 1 else 1 for i in output_shape]
7640
7641 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007642 all_dtypes = [
7643 DType.INT8,
7644 DType.INT16,
7645 DType.INT32,
7646 DType.INT48,
7647 DType.FLOAT,
7648 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01007649 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
7650 outputDType = rng.choice(wrong_dtypes)
7651 else:
7652 outputDType = a.dtype
7653
7654 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07007655
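
    # Shape/type inference for RESHAPE: a -1 in the requested shape is resolved so
    # that the total element count is preserved. A worked example (illustrative
    # only): reshaping a [2, 3, 4] tensor (24 elements) to [4, -1] resolves the -1
    # to 6.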
    @staticmethod
    def reshapeOp(ser, rng, a, shape, error_name=None):
        output_shape = shape.copy()

        totalElements = 1
        for i in a.shape:
            totalElements *= i

        # If there are any -1 elements, figure out what that dimension must be
        totalOutputElements = 1
        for i in output_shape:
            if i != -1:
                totalOutputElements *= i

        # And fill it in
        for i in range(len(output_shape)):
            if output_shape[i] == -1:
                output_shape[i] = totalElements // totalOutputElements

        if error_name == ErrorIf.TensorSizeInputOutputMismatch:
            for i in range(len(output_shape)):
                output_shape[i] = output_shape[i] + rng.integers(1, 10)

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(output_shape, outputDType)
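
    # Shape/type inference for SLICE: the output shape is simply the requested
    # `size`; the SizeOutputShapeMismatch error test nudges each dimension by a
    # small random amount so the declared output no longer matches.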
    @staticmethod
    def sliceOp(ser, rng, a, start, size, error_name=None):

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        if error_name == ErrorIf.SizeOutputShapeMismatch:
            output_shape = size.copy()
            for index in range(len(output_shape)):
                if output_shape[index] <= 2:
                    output_shape[index] = output_shape[index] + rng.choice([1, 2])
                else:
                    output_shape[index] = output_shape[index] + rng.choice(
                        [-2, -1, 1, 2]
                    )
        else:
            output_shape = size.copy()

        return ser.addOutput(output_shape, outputDType)
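
    # Shape/type inference for TILE: each output dimension is the input dimension
    # multiplied by the corresponding entry in `multiples`. A worked example
    # (illustrative only): shape [2, 3] with multiples [3, 1] gives [6, 3].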
    @staticmethod
    def tileOp(ser, rng, a, multiples, error_name=None):

        output_shape = a.shape.copy()
        assert len(multiples) == len(output_shape)

        for i in range(len(output_shape)):
            output_shape[i] = a.shape[i] * multiples[i]

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(output_shape, outputDType)
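
    # Shape/type inference for TRANSPOSE: output dimension i takes the size of
    # input dimension perms[i]. A worked example (illustrative only): shape
    # [1, 2, 3] with perms [2, 0, 1] gives an output shape of [3, 1, 2].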
    @staticmethod
    def transposeOp(ser, rng, a, perms, error_name=None):
        output_shape = a.shape.copy()

        assert len(perms) == len(output_shape)

        if error_name == ErrorIf.IndexOutsideBounds:
            for i in range(len(output_shape)):
                output_shape[i] = a.shape[0]
        else:
            for i in range(len(output_shape)):
                output_shape[i] = a.shape[perms[i]]

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(output_shape, outputDType)
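
    # Shape/type inference for GATHER: with values of shape (N, K, C) and indices
    # of shape (N, W), the output shape is (N, W, C).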
    @staticmethod
    def gatherOp(ser, rng, values, indices, error_name=None):
        if error_name != ErrorIf.WrongRank:
            assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

        output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = values.dtype

        return ser.addOutput(output_shape, outputDType)
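
    # Shape/type inference for SCATTER: values_in is (N, K, C), indices is (N, W)
    # and input is (N, W, C); the output keeps the shape and dtype of values_in.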
    @staticmethod
    def scatterOp(ser, rng, values_in, indices, input, error_name=None):
        if error_name != ErrorIf.WrongRank:
            assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

        output_shape = values_in.shape

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = values_in.dtype

        return ser.addOutput(output_shape, outputDType)
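
    # Shape/type inference for TABLE: the output shape equals the input shape;
    # INT16 inputs produce an INT32 output and INT8 inputs produce an INT8 output.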
    @staticmethod
    def tableOp(ser, rng, input, error_name=None):
        # Same shape as the input, dtype dependent on input dtype
        if error_name != ErrorIf.WrongInputType:
            assert input.dtype == DType.INT16 or input.dtype == DType.INT8
        output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
        if error_name == ErrorIf.WrongOutputType:
            wrong_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes.remove(output_dtype)
            output_dtype = rng.choice(wrong_dtypes)
        return ser.addOutput(input.shape, output_dtype)
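
    # Shape/type inference for RESIZE: the output is NHWC, taking batch and
    # channels from the input and the spatial size from output_dims; the batch,
    # channel and rank error tests deliberately distort those dimensions.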
    @staticmethod
    def resizeOp(
        serializer,
        rng,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        error_name=None,
    ):
        if error_name == ErrorIf.WrongRank:
            output_dims = [
                input.shape[0],
                output_dims[0],
                output_dims[0],
                input.shape[0],
            ]
        else:
            if error_name == ErrorIf.BatchMismatch:
                output_dims = [
                    input.shape[0] + rng.integers(1, 10),
                    output_dims[0],
                    output_dims[1],
                    input.shape[3],
                ]
            elif error_name == ErrorIf.ChannelMismatch:
                output_dims = [
                    input.shape[0],
                    output_dims[0],
                    output_dims[1],
                    input.shape[3] + rng.integers(1, 10),
                ]
            else:
                output_dims = [
                    input.shape[0],
                    output_dims[0],
                    output_dims[1],
                    input.shape[3],
                ]

        return serializer.addOutput(output_dims, output_dtype)
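
    # Shape/type inference for type-conversion (CAST-style) ops: the shape is
    # unchanged and the caller supplies the destination dtype directly.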
    @staticmethod
    def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
        return ser.addOutput(val.shape, out_dtype)
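
    # Shape/type inference for TRANSPOSE_CONV2D: the accumulator dtype follows the
    # input dtype (INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT) and the output
    # shape is supplied by the caller.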
    @staticmethod
    def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None):
        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            out_dtype = DType.INT32
        else:
            raise Exception(f"Unsupported input dtype: {ifm.dtype}")

        if error_name == ErrorIf.WrongOutputType:
            wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
            out_dtype = rng.choice(wrong_dtypes)

        return ser.addOutput(output_shape, out_dtype)