# Copyright (c) 2020-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
import math
import os
from copy import deepcopy

import numpy as np
import serializer.tosa_serializer as ts
from generator.tosa_error_if import ErrorIf
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode

# DTypeNames, DType, Op and ResizeMode are convenience variables to the
# flatc-generated types that should be enums, but aren't


def valueToName(item, value):
    """Get the name of an attribute with the given value.

    This convenience function is needed to print meaningful names for
    the values of the tosa.Op.Op and tosa.DType.DType classes.
    This would not be necessary if they were subclasses of Enum, or
    IntEnum, which, sadly, they are not.

    Args:
        item: The class, or object, to find the value in
        value: The value to find

    Example, to get the name of a DType value:

        name = valueToName(DType, DType.INT8) # returns 'INT8'
        name = valueToName(DType, 4) # returns 'INT8'

    Returns:
        The name of the first attribute found with a matching value

    Raises:
        ValueError if the value is not found
    """
    for attr in dir(item):
        if getattr(item, attr) == value:
            return attr
    raise ValueError(f"value ({value}) not found")


def allDTypes(*, excludes=None):
    """Get a set of all DType values, optionally excluding some values.

    This convenience function is needed to provide a sequence of DType values.
    This would be much easier if DType was a subclass of Enum, or IntEnum,
    as we could then iterate over the values directly, instead of using
    dir() to find the attributes and then check if they are what we want.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    excludes = () if not excludes else excludes
    return {
        getattr(DType, t)
        for t in dir(DType)
        if not callable(getattr(DType, t))
        and not t.startswith("__")
        and getattr(DType, t) not in excludes
    }


def usableDTypes(*, excludes=None):
    """Get a set of usable DType values, optionally excluding some values.

    Excludes (DType.UNKNOWN, DType.UINT8) in addition to the excludes
    specified by the caller, as the serializer lib does not support them.
    If you wish to include 'UNKNOWN' or 'UINT8' use allDTypes instead.

    Args:
        excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])

    Returns:
        A set of DType values
    """
    omit = {DType.UNKNOWN, DType.UINT8}
    omit.update(excludes if excludes else ())
    return allDTypes(excludes=omit)


def product(shape):
    value = 1
    for n in shape:
        value *= n
    return value


class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype, error_name=None):

        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            zero_point = testGen.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype, error_name),
                TosaQuantGen.getQinfo(testGen, dtype),
            )
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype),
                TosaQuantGen.getQinfo(testGen, dtype, error_name),
            )
        else:
            qinfo.UnaryQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype),
                TosaQuantGen.getQinfo(testGen, dtype),
            )
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], error_name)
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], error_name)
        else:
            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])

        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.MatMulQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype, error_name),
                TosaQuantGen.getQinfo(testGen, dtype, error_name),
            )
        else:
            qinfo.MatMulQuantInfo(
                TosaQuantGen.getQinfo(testGen, dtype),
                TosaQuantGen.getQinfo(testGen, dtype),
            )
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype, error_name=None):
        qinfo = ts.TosaSerializerQuantInfo()
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, error_name))
        else:
            qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift
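        # Worked example (illustrative): scaleFp=0.5 with scale32=True gives
        # math.frexp(0.5) == (0.5, 0), so multiplier = round(0.5 * 2**31) = 2**30
        # and shift = -0 + 31 = 31, i.e. 2**30 * 2**-31 recovers the 0.5 scale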

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
        #   scaleFp, scaleBits, m, multiplier, shift))

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift


class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
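        # e.g. with shape [3, 4, 5] the chosen input may become [3, 4, 1] in the
        # non-error path below, so it still broadcasts against the other operands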
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                if error_name == ErrorIf.DimensionMismatch:
                    shape_bcast[fuzz_idx] += 1
                elif error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
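        # e.g. four requested shapes of [8, 4] concatenated on axis 0 come back
        # as [8, 4], [4, 4], [2, 4], [2, 4] (sketch of the non-ERROR_IF path)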
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList


class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for
    operators that take attributes or other parameters.

    The return value is a list of (descriptive_name, [arglist]) tuples where
    the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function.
    """

    def __init__(self):
        pass

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype, error_name=None):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype, error_name=None):
        """Build the axis argument for operators that take a single axis"""
        axes = []
        shape = shapeList[0]

        if error_name == ErrorIf.AxisSmallerZero:
            small_axis = testGen.rng.integers(-5, 0)
            axes.append(("axis{}".format(small_axis), [small_axis]))
        elif error_name == ErrorIf.AxisLargerRank:
            large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
            axes.append(("axis{}".format(large_axis), [large_axis]))
        else:
            for a in range(0, len(shape)):
                axes.append(("axis{}".format(a), [a]))

        return axes

    @staticmethod
    def agConv(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]
        # determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3])
        k = [int(x) for x in opName.split("_")[-1].split("x")]

        # Check the rank
        rank = 5 if opName.startswith("conv3d") else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits batch and channels
        k_rank = rank - 2
        assert len(k) == k_rank

        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        if error_name == ErrorIf.PadSmallerZero:
            p_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * k_rank))}
        if error_name == ErrorIf.DilationSmallerOne:
            d_vals = [testGen.rng.choice(range(-5, 1))]
        else:
            d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

        if not error_name and testGen.args.oversize:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    {x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))}
                )
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
            bigDilation = 7
            dilations.update(
                {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
            )

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 100
        sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
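        # e.g. 81 paddings x 9 strides x 9 dilations (6561 combinations) with
        # sparsity_factor 100 gives sparsity 66, nudged up to 67, so the loops
        # below keep roughly every 67th combination (illustrative numbers only)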

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # padding must not exceed the kernel size ?
                        # and p[0] < k[0] and p[1] < k[0]
                        # and p[2] < k[1] and p[3] < k[1]
                        # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
                        # the padded shape must exceed the kernel size
                        and (ifm_shape[1] + p[0] + p[1]) > k[0]
                        and (ifm_shape[2] + p[2] + p[3]) > k[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
                        # the padded shape must exceed the dilation
                        and (ifm_shape[1] + p[0] + p[1]) > d[0]
                        and (ifm_shape[2] + p[2] + p[3]) > d[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
                    ):
                        arg_list.append(
                            (
                                "st{}_pad{}_dilat{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in p]),
                                    "".join([str(x) for x in d]),
                                ),
                                [s, p, d],
                            )
                        )
                    n += 1

        return arg_list

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == 4
            assert len(filter_shape) == 4

        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        if error_name == ErrorIf.PadSmallerZero:
            p_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 2))}
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        if error_name == ErrorIf.DilationSmallerOne:
            d_vals = [testGen.rng.choice(range(-5, 1))]
        else:
            d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * 2))}

        if not error_name:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    {x for x in itertools.product(*([[0, bigPadding]] * 2))}
                )
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
            bigDilation = 7
            dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 100
        sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if n % sparsity == 0:
                        # Determine the output shape
                        oh = (
                            ifm_shape[1]
                            - filter_shape[1]
                            - (filter_shape[1] - 1) * (d[0] - 1)
                            + 2 * p[0]
                        ) // s[0] + 1
                        ow = (
                            ifm_shape[2]
                            - filter_shape[2]
                            - (filter_shape[2] - 1) * (d[1] - 1)
                            + 2 * p[1]
                        ) // s[1] + 1
                        os = [ifm_shape[0], oh, ow, filter_shape[0]]
                        arg_list.append(
                            (
                                "st{}_pad{}_dilat{}_os{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in p]),
                                    "".join([str(x) for x in d]),
                                    "x".join([str(x) for x in os]),
                                ),
                                [s, p, d, os],
                            )
                        )
                    n += 1

        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
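        # e.g. a rank 2 input with pad_min=0, pad_max=1 yields 16 combinations,
        # named "pad0000" .. "pad1111" for ((0, 0), (0, 0)) .. ((1, 1), (1, 1))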
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype == DType.FLOAT:
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            return []

        for paddings in shape_pad_values:
            name = "pad"
            for r in range(rank):
                before, after = paddings[r]
                name = f"{name}{before}{after}"
            arg_list.append((name, [np.array(paddings), pad_const_int, pad_const_fp]))

        return arg_list

    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        shape = shapeList[0]
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        if testGen.args.oversize:
            # add some oversize argument values
            bigStride = 7
            strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
            bigKernel = 6
            kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
            if max(shape) < 64:
                # padding must be less than the kernel size
                bigPadding = bigKernel - 1
                paddings.update(
                    {x for x in itertools.product(*([[0, bigPadding]] * 4))}
                )

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 500
        sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for k in sorted(list(kernels)):
                    if error_name in [
                        ErrorIf.StrideSmallerOne,
                        ErrorIf.KernelSmallerOne,
                        ErrorIf.PadSmallerZero,
                        ErrorIf.PadLargerEqualKernel,
                    ]:
                        sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
                            testGen, error_name, s, p, k
                        )
                        if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                            arg_list.append(
                                (
                                    "st{}_kern{}_pad{}".format(
                                        "".join([str(x) for x in sNew]),
                                        "".join([str(x) for x in kNew]),
                                        "".join([str(x) for x in pNew]),
                                    ),
                                    [sNew, pNew, kNew],
                                )
                            )
                    elif (
                        n % sparsity == 0
                        # padding must not exceed the kernel size
                        and p[0] < k[0]
                        and p[1] < k[0]
                        and p[2] < k[1]
                        and p[3] < k[1]
                        # the padded shape must exceed the kernel size
                        and (shape[1] + p[0] + p[1]) > k[0]
                        and (shape[2] + p[2] + p[3]) > k[1]
                    ):
                        arg_list.append(
                            (
                                "st{}_kern{}_pad{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in k]),
                                    "".join([str(x) for x in p]),
                                ),
                                [s, p, k],
                            )
                        )
                    n += 1

        return arg_list

    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        if error_name == ErrorIf.WrongOutputType:
            dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
        elif inDtype == DType.INT8:
            dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT16:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT32:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FLOAT:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output type for incorrect input type
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))

        return arg_list

    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
            if (
                dtype in [DType.UINT8, DType.INT8]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                continue
            if (
                inDtype == DType.UINT8
                and dtype != DType.INT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT8 is INT8, skip all other combinations
                continue
            if (
                inDtype != DType.INT8
                and dtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT8 is INT8, skip all other combinations
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype)
            ):
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition. Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[dtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [dtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

    @staticmethod
    def agMul(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        if dtype is DType.INT32:
            for p in range(testGen.args.num_rand_permutations):

                shift = testGen.randInt(0, 32)

                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("perm0_shift0", [0]))

        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    # Helper function for reshape. Gets some factors of a larger number.
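    # For example, getFactors(12) returns [1, 2, 3] (divisors up to sqrt(12)).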
1086 @staticmethod
1087 def getFactors(val, start=1):
1088 factors = []
1089
Matthew Haddon2ad047d2021-06-22 16:55:23 +01001090 for i in range(start, int(np.sqrt(val)) + 1):
Eric Kunzee5e26762020-10-13 16:11:07 -07001091 if (val % i) == 0:
1092 factors.append(i)
1093
1094 return factors
1095
1096 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001097 def agReshape(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001098 arg_list = []
1099
1100 origShape = shapeList[0]
1101
1102 totalElements = 1
1103 for s in origShape:
1104 totalElements *= s
1105
1106 # This code is NOT fast. Fortunately, the numbers are fairly small.
1107 factors = TosaArgGen.getFactors(totalElements)
1108
1109 for p in range(testGen.args.num_rand_permutations):
Matthew Haddon5fc4e682021-07-07 11:28:29 +01001110 newRank = testGen.randInt(1, 7)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001111 if len(factors) < newRank:
Eric Kunzee5e26762020-10-13 16:11:07 -07001112 continue
1113
Matthew Haddon2ad047d2021-06-22 16:55:23 +01001114 found = True
1115 # escape_counter breaks while loop if it continues on for too long
1116 escape_counter = 0
1117 while found:
1118 newShape = []
1119 # Generate newShape ensuring it isn't a duplicate
1120 remainingElements = totalElements
1121 shuffledFactors = testGen.rng.permutation(factors)
Matthew Haddon5fc4e682021-07-07 11:28:29 +01001122 for i in range(1, newRank):
Matthew Haddon2ad047d2021-06-22 16:55:23 +01001123 # pick rank-1 factors
1124 newShape.append(shuffledFactors[0])
1125 remainingElements = remainingElements // shuffledFactors[0]
1126 shuffledFactors = testGen.rng.permutation(
1127 TosaArgGen.getFactors(remainingElements)
1128 )
1129 newShape.append(remainingElements)
Eric Kunzee5e26762020-10-13 16:11:07 -07001130
Matthew Haddon2ad047d2021-06-22 16:55:23 +01001131 # Toss in a -1 sometimes
1132 minusOne = testGen.randInt(0, newRank * 4)
1133 if minusOne < newRank:
1134 newShape[minusOne] = -1
Eric Kunzee5e26762020-10-13 16:11:07 -07001135
Matthew Haddon2ad047d2021-06-22 16:55:23 +01001136 # Check for duplicates
1137 found = False
1138 for name, other_shape in arg_list:
1139 if other_shape[0] == newShape:
1140 found = True
1141 break
1142
1143 escape_counter += 1
1144 if escape_counter >= 100:
1145 break
1146
1147 if not found:
1148 arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001149
1150 return arg_list
1151
Eric Kunzee5e26762020-10-13 16:11:07 -07001152 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001153 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001154 arg_list = []
1155
1156 ifm_shape = shapeList[0]
1157
Matthew Haddone807aae2021-10-11 18:12:58 +01001158 if error_name == ErrorIf.IndexOutsideBounds:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001159 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
Matthew Haddone807aae2021-10-11 18:12:58 +01001160 incorrect_small_index = range(-len(ifm_shape), 0)
1161 permutations = [p for p in itertools.permutations(incorrect_large_index)]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001162 permutations.extend(
1163 [p for p in itertools.permutations(incorrect_small_index)]
1164 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001165 elif error_name == ErrorIf.IndexUsedTwice:
1166 # Create list with a duplicated index
1167 perm_range = list(range(len(ifm_shape)))
1168 index_choice = testGen.rng.choice(range(len(perm_range)))
1169 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
1170 permutations = [p for p in itertools.permutations(perm_range)]
1171
Matthew Haddone807aae2021-10-11 18:12:58 +01001172 else:
1173 # Get all permutations
1174 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
Eric Kunzee5e26762020-10-13 16:11:07 -07001175
Jeremy Johnsona6185572021-06-21 15:55:35 +01001176 # Limit to possible permutations from shape dimension or argument setting
1177 limit = min(len(permutations), testGen.args.num_rand_permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -07001178
Jeremy Johnsona6185572021-06-21 15:55:35 +01001179 # Get random permutation generator that uses all permutations
1180 random_permutations = testGen.rng.permutation(permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -07001181
Jeremy Johnsona6185572021-06-21 15:55:35 +01001182 # Create list of required amount of permutations
Kevin Chengacb550f2021-06-29 15:32:19 -07001183 arg_list = [
1184 ("perm{}".format(p), [random_permutations[p].tolist()])
1185 for p in range(limit)
1186 ]
Eric Kunzee5e26762020-10-13 16:11:07 -07001187 return arg_list
1188
1189 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001190 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001191 arg_list = []
1192
1193 ifm_shape = shapeList[0]
1194 rank = len(ifm_shape)
1195
1196 for p in range(testGen.args.num_rand_permutations):
Matthew Haddone807aae2021-10-11 18:12:58 +01001197 start = []
Eric Kunzee5e26762020-10-13 16:11:07 -07001198 size = []
1199
Kevin Cheng550ccc52021-03-03 11:21:43 -08001200 valid = True
Eric Kunzee5e26762020-10-13 16:11:07 -07001201
1202 for i in range(rank):
1203 if ifm_shape[i] > 1:
Matthew Haddone807aae2021-10-11 18:12:58 +01001204 start.append(testGen.randInt(0, ifm_shape[i]))
1205 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001206
1207 # Invalid slice size?
1208 if size[i] == 0:
1209 valid = False
1210 else:
Matthew Haddone807aae2021-10-11 18:12:58 +01001211 start.append(0)
Eric Kunzee5e26762020-10-13 16:11:07 -07001212 size.append(1)
1213
1214 if valid:
Matthew Haddone807aae2021-10-11 18:12:58 +01001215 # If ERROR_IF test required then incorrect start, size will be returned
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001216 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
1217 testGen, error_name, ifm_shape, start, size
1218 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001219 arg_list.append(("perm{}".format(p), [start, size]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001220 return arg_list
1221
1222 @staticmethod
Matthew Haddon1c00b712021-10-01 15:51:03 +01001223 def agTile(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001224 arg_list = []
1225
1226 ifm_shape = shapeList[0]
1227 rank = len(ifm_shape)
1228
1229 for p in range(testGen.args.num_rand_permutations):
1230
1231 # Pick a few random, but small multiple values
1232 # because otherwise this has a tendency to generate
1233 # enormous tensors
1234 multiples = []
1235 for i in range(rank):
Matthew Haddon82ad4d62021-08-20 15:02:39 +01001236 if ifm_shape[i] > 1000:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001237 # Use a multiple of 1 if the ifm_shape dimension is large, to reduce
1238 # the overall tensor size
Matthew Haddon82ad4d62021-08-20 15:02:39 +01001239 multiples.append(1)
1240 elif max(ifm_shape) > 1000:
1241 multiples.append(2)
1242 else:
1243 multiples.append(testGen.randInt(1, 4))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001244 arg_list.append(("perm{}".format(p), [multiples]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001245
1246 return arg_list
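    # Illustration (assumed shape) of how the multiples above are capped to
    # keep TILE outputs small:
    #
    #   # ifm_shape = [1, 1200, 3, 2]
    #   # dim 0: 1    -> max(shape) > 1000, so multiple = 2
    #   # dim 1: 1200 -> the dimension itself is > 1000, so multiple = 1
    #   # dim 2: 3    -> multiple = 2; dim 3: 2 -> multiple = 2
    #   # Without the cap, a larger random multiple on the 1200-wide axis could
    #   # easily blow up the output tensor for a single test case.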
1247
1248 @staticmethod
Matthew Haddone86fd342021-09-07 16:12:21 +01001249 def agResize(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001250 arg_list = []
1251
1252 ifm_shape = shapeList[0]
Matthew Haddon848efb42021-09-09 12:30:53 +01001253 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Eric Kunzee5e26762020-10-13 16:11:07 -07001254
1255 # Exclude illegal {mode, type} configurations. Pick legal output types
Matthew Haddon848efb42021-09-09 12:30:53 +01001256 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
Les Bell33d837e2021-08-10 08:34:43 +01001257 outputDTypeList = [DType.INT8]
Matthew Haddon848efb42021-09-09 12:30:53 +01001258 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001259 outputDTypeList = [DType.INT16]
Matthew Haddon848efb42021-09-09 12:30:53 +01001260 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
Les Bell33d837e2021-08-10 08:34:43 +01001261 outputDTypeList = [DType.INT32]
Matthew Haddon848efb42021-09-09 12:30:53 +01001262 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001263 outputDTypeList = [DType.INT48]
Kevin Cheng77d0f762020-11-24 10:26:32 -08001264 elif dtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001265 outputDTypeList = [DType.FLOAT]
Matthew Haddon848efb42021-09-09 12:30:53 +01001266 elif error_name == ErrorIf.WrongInputType:
1267 # If an incorrect input type is used then we set a 'correct'
1268 # output type to avoid other errors
1269 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -07001270 else:
1271 continue
1272
1273 for outputDType in outputDTypeList:
1274 for perm in range(testGen.args.num_rand_permutations):
Eric Kunzee5e26762020-10-13 16:11:07 -07001275 # Randomly generate legal output dimensions and shift
1276 # and then compute the stride and offset based on them
Matthew Haddone86fd342021-09-07 16:12:21 +01001277 # An output_dim of 1 would cause the offset to exceed the allowed range,
1278 # so a minimum value of 2 is produced below
1279 output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001280 while (float(ifm_shape[1]) / float(output_dims[0])) >= 16:
Matthew Haddone86fd342021-09-07 16:12:21 +01001281 output_dims[0] += 1
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001282 while (float(ifm_shape[2]) / float(output_dims[1])) >= 16:
Matthew Haddone86fd342021-09-07 16:12:21 +01001283 output_dims[1] += 1
1284
Kevin Cheng77d0f762020-11-24 10:26:32 -08001285 in_center_h = (ifm_shape[1] - 1) / 2.0
1286 in_center_w = (ifm_shape[2] - 1) / 2.0
1287 out_center_h = (output_dims[0] - 1) / 2.0
1288 out_center_w = (output_dims[1] - 1) / 2.0
Eric Kunzee5e26762020-10-13 16:11:07 -07001289
Kevin Cheng77d0f762020-11-24 10:26:32 -08001290 fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
1291 fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
1292 fp_offset_y = in_center_h - fp_stride_y * out_center_h
1293 fp_offset_x = in_center_w - fp_stride_x * out_center_w
Eric Kunzee5e26762020-10-13 16:11:07 -07001294
Kevin Cheng77d0f762020-11-24 10:26:32 -08001295 if outputDType == DType.FLOAT:
Jeremy Johnson27cf5432021-11-16 11:12:17 +00001296 float_op = True
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001297 arg_str = (
1298 "mode{}_shift{}_odim{}x{}_out{}"
1299 "_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}"
1300 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001301 shift = 0
1302 stride = [0, 0]
1303 offset = [0, 0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08001304 stride_fp = [fp_stride_y, fp_stride_x]
1305 offset_fp = [fp_offset_y, fp_offset_x]
Matthew Haddone86fd342021-09-07 16:12:21 +01001306
Kevin Cheng77d0f762020-11-24 10:26:32 -08001307 else:
Jeremy Johnson27cf5432021-11-16 11:12:17 +00001308 float_op = False
1309 arg_str = "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}"
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001310 shift = testGen.randInt(1, 12)
Jeremy Johnsonc0b24f02021-10-28 17:12:42 +01001311 # Now search for a shift value (1 to 11) that will produce
1312 # a valid and predictable resize operation
1313 count = 0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001314 while count < 12:
Kevin Cheng77d0f762020-11-24 10:26:32 -08001315 unit = float(1 << shift)
1316 stride_y = int(round(fp_stride_y * unit))
1317 stride_x = int(round(fp_stride_x * unit))
1318 offset_y = int(round(fp_offset_y * unit))
1319 offset_x = int(round(fp_offset_x * unit))
Eric Kunzee5e26762020-10-13 16:11:07 -07001320
Jeremy Johnsonc0b24f02021-10-28 17:12:42 +01001321 if (
Jeremy Johnson27cf5432021-11-16 11:12:17 +00001322 stride_y <= 0
1323 or stride_x <= 0
1324 or stride_y >= (16 << shift)
Jeremy Johnsonc0b24f02021-10-28 17:12:42 +01001325 or stride_x >= (16 << shift)
1326 or offset_y >= (16 << shift)
1327 or offset_x >= (16 << shift)
1328 or offset_y <= (-16 << shift)
1329 or offset_x <= (-16 << shift)
1330 ):
1331 # Change the shift value and check again
1332 count += 1
1333 shift = (shift % 11) + 1
1334 continue
1335
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001336 def RESIZE_REQUIRE_CALC(
1337 length_in, length_out, stride, offset, shift
1338 ):
Jeremy Johnsonc0b24f02021-10-28 17:12:42 +01001339 # Perform the pseudo loop to look for out of bounds
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001340 for pos in range(0, length_out):
Jeremy Johnsonc0b24f02021-10-28 17:12:42 +01001341 a = pos * stride + offset
1342 ia = a >> shift
1343 ia0 = max(ia, 0)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001344 ia1 = min(ia + 1, length_in - 1)
Jeremy Johnsonc0b24f02021-10-28 17:12:42 +01001345 if ia0 > ia1:
1346 # Found a problem value
1347 break
1348 return ia0, ia1
1349
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001350 iy0, iy1 = RESIZE_REQUIRE_CALC(
1351 ifm_shape[1], output_dims[0], stride_y, offset_y, shift
1352 )
1353 ix0, ix1 = RESIZE_REQUIRE_CALC(
1354 ifm_shape[2], output_dims[1], stride_x, offset_x, shift
1355 )
Jeremy Johnsonc0b24f02021-10-28 17:12:42 +01001356 if ix0 > ix1 or iy0 > iy1:
1357 # Change the shift value and check again
1358 count += 1
1359 shift = (shift % 11) + 1
1360 continue
1361 break
1362
1363 if count >= 12:
1364 # Couldn't find a good set of values for this test, skip it
1365 continue
1366
Kevin Cheng550ccc52021-03-03 11:21:43 -08001367 stride = [stride_y, stride_x]
1368 offset = [offset_y, offset_x]
Kevin Cheng77d0f762020-11-24 10:26:32 -08001369
1370 stride_fp = [0.0, 0.0]
1371 offset_fp = [0.0, 0.0]
1372
Jeremy Johnson27cf5432021-11-16 11:12:17 +00001373 # Common for all data types
1374 if error_name is not None:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001375 (
1376 shift,
1377 stride,
1378 stride_fp,
1379 offset,
1380 offset_fp,
1381 outputDTypeNew,
1382 ) = TosaErrorIfArgGen.eiResizeErrorIf(
Jeremy Johnson27cf5432021-11-16 11:12:17 +00001383 testGen,
1384 error_name,
1385 mode,
1386 dtype,
1387 shapeList,
1388 outputDType,
1389 shift,
1390 stride,
1391 stride_fp,
1392 offset,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001393 offset_fp,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001394 )
Jeremy Johnson27cf5432021-11-16 11:12:17 +00001395 else:
1396 outputDTypeNew = outputDType
1397
1398 arg_list.append(
1399 (
1400 arg_str.format(
1401 "N" if mode == ResizeMode.NEAREST else "B",
1402 shift,
1403 output_dims[0],
1404 output_dims[1],
1405 testGen.typeStr(outputDTypeNew),
1406 stride_fp[0] if float_op else stride[0],
1407 stride_fp[1] if float_op else stride[1],
1408 offset_fp[0] if float_op else offset[0],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001409 offset_fp[1] if float_op else offset[1],
Jeremy Johnson27cf5432021-11-16 11:12:17 +00001410 ),
1411 [
1412 mode,
1413 stride,
1414 offset,
1415 shift,
1416 stride_fp,
1417 offset_fp,
1418 output_dims,
1419 dtype,
1420 outputDTypeNew,
1421 ],
1422 )
1423 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001424
1425 return arg_list
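    # Worked numerical example (values assumed) of the fixed-point stride and
    # offset computation above for an integer RESIZE:
    #
    #   # ifm H = 12, output H = 5, shift = 8 -> unit = 1 << 8 = 256
    #   # fp_stride_y = 12 / 5 = 2.4          -> stride_y = round(2.4 * 256) = 614
    #   # fp_offset_y = (12 - 1)/2 - 2.4 * (5 - 1)/2 = 0.7
    #   #                                     -> offset_y = round(0.7 * 256) = 179
    #   # Both values are below 16 << 8 = 4096, so this (stride, offset, shift)
    #   # combination passes the range checks; otherwise the loop above bumps
    #   # shift (wrapping through 1..11) and tries again.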
1426
Kevin Chengfe392ce2021-10-18 21:51:55 +00001427 @staticmethod
1428 def agTable(testGen, opName, shapeList, dtype, error_name=None):
1429 arg_list = []
1430
1431 if dtype == DType.INT8:
1432 table = np.int32(
1433 testGen.rng.integers(low=-128, high=128, size=[256])
1434 ).tolist()
1435 else: # INT16
1436 table = np.int32(
1437 testGen.rng.integers(low=-32768, high=32768, size=[513])
1438 ).tolist()
1439
1440 arg_list.append(
1441 (
1442 "",
1443 [table],
1444 )
1445 )
1446 return arg_list
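    # Note on the table sizes above: TOSA TABLE takes a 256-entry table for an
    # INT8 input (one entry per input value) and a 513-entry table for INT16
    # (512 intervals plus a final end point used for interpolation). A
    # hand-built identity table, as a hypothetical alternative to the random
    # one generated here, would be:
    #
    #   # identity_int8_table = list(range(-128, 128))   # exactly 256 entries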
1447
Matthew Haddon1c00b712021-10-01 15:51:03 +01001448 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001449 # CondIf generates the condition values here.
1450 # Convert to tensors in the build function, along with the
1451 # then and else blocks
1452 arg_list = []
1453
1454 for c in [False, True]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001455 arg_list.append(("cond{}".format(int(c)), [c]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001456
1457 return arg_list
1458
Matthew Haddon1c00b712021-10-01 15:51:03 +01001459 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001460 # While loop: 0 iterations, 1, more than 1
1461 arg_list = []
1462
1463 for iter in [0, 1, 4]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001464 arg_list.append(("iter{}".format(iter), [iter]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001465
1466 return arg_list
1467
Matthew Haddone86fd342021-09-07 16:12:21 +01001468
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001469class TosaErrorIfArgGen:
Matthew Haddone86fd342021-09-07 16:12:21 +01001470 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001471 def eiResizeErrorIf(
1472 testGen,
1473 error_name,
1474 mode,
1475 dtype,
1476 shapeList,
1477 outputDType,
1478 shift,
1479 stride,
1480 stride_fp,
1481 offset,
1482 offset_fp,
1483 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01001484
1485 if outputDType == DType.FLOAT:
1486 if error_name == ErrorIf.StrideSmallerEqualZero:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001487 stride_fp = testGen.rng.random(size=[2]) - 2
Matthew Haddone86fd342021-09-07 16:12:21 +01001488 elif error_name == ErrorIf.ShiftNotZero:
1489 shift = testGen.rng.integers(1, 5)
1490 elif error_name == ErrorIf.StrideLargerDimension:
1491 shape = shapeList[0]
1492 transform_height = testGen.rng.choice([False, True])
1493 if transform_height:
1494 stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
1495 else:
1496 stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
1497 else:
1498 if error_name == ErrorIf.StrideSmallerEqualZero:
1499 stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
1500 elif error_name == ErrorIf.ShiftSmallerOne:
1501 shift = testGen.rng.integers(-3, 1)
1502 if shift <= 0:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001503 stride = [
1504 (16 >> -shift) - 1,
1505 (16 >> -shift) - 1,
1506 ] # avoids other ERROR_IF checks
1507 offset = [
1508 (16 >> -shift) - 1,
1509 (16 >> -shift) - 1,
1510 ] # avoids other ERROR_IF checks
Matthew Haddone86fd342021-09-07 16:12:21 +01001511 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001512 stride = [
1513 (16 << shift) - 1,
1514 (16 << shift) - 1,
1515 ] # avoids other ERROR_IF checks
1516 offset = [
1517 (16 << shift) - 1,
1518 (16 << shift) - 1,
1519 ] # avoids other ERROR_IF checks
Matthew Haddone86fd342021-09-07 16:12:21 +01001520 elif error_name == ErrorIf.ShiftLargerEleven:
1521 shift = np.int16(testGen.rng.integers(12, 15))
1522 elif error_name == ErrorIf.StrideLargerDimension:
1523 shape = shapeList[0]
1524 transform_height = testGen.rng.choice([False, True])
1525 if transform_height:
1526 stride[0] = shape[1] + testGen.rng.integers(1, 10)
1527 else:
1528 stride[1] = shape[2] + testGen.rng.integers(1, 10)
1529 elif error_name == ErrorIf.StrideLargerEqualMax:
1530 stride = [(16 << shift) + 1, (16 << shift) + 1]
1531 elif error_name == ErrorIf.OffsetLargerEqualMax:
1532 offset = [(16 << shift) + 1, (16 << shift) + 1]
1533 elif error_name == ErrorIf.OffsetSmallerEqualMin:
1534 offset = [(-16 << shift) - 1, (-16 << shift) - 1]
1535
Matthew Haddon848efb42021-09-09 12:30:53 +01001536 if error_name == ErrorIf.WrongOutputType:
1537 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001538 incorrect_types = (
1539 DType.INT4,
1540 DType.INT16,
1541 DType.INT32,
1542 DType.INT48,
1543 DType.FLOAT,
1544 )
Matthew Haddon848efb42021-09-09 12:30:53 +01001545 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001546 incorrect_types = (
1547 DType.INT4,
1548 DType.INT8,
1549 DType.INT32,
1550 DType.INT48,
1551 DType.FLOAT,
1552 )
Matthew Haddon848efb42021-09-09 12:30:53 +01001553 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001554 incorrect_types = (
1555 DType.INT4,
1556 DType.INT8,
1557 DType.INT16,
1558 DType.INT48,
1559 DType.FLOAT,
1560 )
Matthew Haddon848efb42021-09-09 12:30:53 +01001561 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001562 incorrect_types = (
1563 DType.INT4,
1564 DType.INT8,
1565 DType.INT16,
1566 DType.INT32,
1567 DType.FLOAT,
1568 )
Matthew Haddon848efb42021-09-09 12:30:53 +01001569 elif dtype == DType.FLOAT:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001570 incorrect_types = (
1571 DType.INT4,
1572 DType.INT8,
1573 DType.INT16,
1574 DType.INT32,
1575 DType.INT48,
1576 )
Matthew Haddon848efb42021-09-09 12:30:53 +01001577 outputDType = testGen.rng.choice(a=incorrect_types)
Matthew Haddone86fd342021-09-07 16:12:21 +01001578
Matthew Haddon848efb42021-09-09 12:30:53 +01001579 return shift, stride, stride_fp, offset, offset_fp, outputDType
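    # Minimal usage sketch (objects assumed): agResize calls this helper when
    # an error_name is set, and the returned values replace the legal ones so
    # that exactly one ERROR_IF condition fires.
    #
    #   # shift, stride, stride_fp, offset, offset_fp, outputDType = (
    #   #     TosaErrorIfArgGen.eiResizeErrorIf(
    #   #         testGen, ErrorIf.ShiftLargerEleven, ResizeMode.NEAREST,
    #   #         DType.INT8, shapeList, DType.INT8,
    #   #         shift, stride, stride_fp, offset, offset_fp,
    #   #     )
    #   # )
    #   # For ShiftLargerEleven only 'shift' is replaced (with a value in
    #   # 12..14); the other parameters pass through unchanged.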
1580
1581 @staticmethod
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001582 def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001583 if (
1584 error_name == ErrorIf.StrideSmallerOne
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001585 # padding must not exceed the kernel size
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001586 and pad[0] < kernel[0]
1587 and pad[1] < kernel[0]
1588 and pad[2] < kernel[1]
1589 and pad[3] < kernel[1]
1590 ):
1591 wrongStride = (
1592 testGen.rng.choice([0, -1, -2, -3]),
1593 testGen.rng.choice([0, -1, -2, -3]),
1594 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001595 return wrongStride, pad, kernel
1596 elif error_name == ErrorIf.PadSmallerZero:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001597 wrongPad = (
1598 testGen.rng.choice([-1, -2, -3]),
1599 testGen.rng.choice([-1, -2, -3]),
1600 testGen.rng.choice([-1, -2, -3]),
1601 testGen.rng.choice([-1, -2, -3]),
1602 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001603 return stride, wrongPad, kernel
1604 elif error_name == ErrorIf.KernelSmallerOne:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001605 wrongKernel = (
1606 testGen.rng.choice([0, -1, -2, -3]),
1607 testGen.rng.choice([0, -1, -2, -3]),
1608 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001609 return stride, pad, wrongKernel
1610 elif error_name == ErrorIf.PadLargerEqualKernel:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001611 wrongPad = (
1612 testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
1613 testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
1614 testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
1615 testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
1616 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001617 return stride, wrongPad, kernel
1618 else:
1619 return None, None, None
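    # Illustrative call (values assumed) showing how one pooling parameter is
    # corrupted per ERROR_IF case while the others pass through:
    #
    #   # stride, pad, kernel = TosaErrorIfArgGen.eiPoolingErrorIf(
    #   #     testGen, ErrorIf.PadSmallerZero, (1, 1), (0, 0, 0, 0), (2, 2)
    #   # )
    #   # -> stride == (1, 1), kernel == (2, 2), and pad becomes four values
    #   #    drawn from {-1, -2, -3}, violating the pad >= 0 ERROR_IF check.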
1620
Matthew Haddonc2025212021-10-08 21:21:05 +01001621 @staticmethod
1622 def eiRescaleWrongOutputType(input_dtype, output_dtype):
1623 if input_dtype == DType.INT8:
1624 if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
1625 return True
1626 if input_dtype in [DType.INT16, DType.INT32]:
1627 if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
1628 return True
1629 elif input_dtype == DType.INT48:
1630 if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
1631 return True
1632 elif input_dtype == DType.UINT8:
1633 if output_dtype != DType.INT8:
1634 return True
1635 return False
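    # Examples of the RESCALE type table encoded above (pairs chosen for
    # illustration):
    #
    #   # eiRescaleWrongOutputType(DType.INT8, DType.INT48)  -> True  (invalid)
    #   # eiRescaleWrongOutputType(DType.INT8, DType.UINT8)  -> False (legal)
    #   # eiRescaleWrongOutputType(DType.UINT8, DType.INT16) -> True  (UINT8 may
    #   #                                         only rescale to INT8 here)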
1636
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001637 @staticmethod
Matthew Haddon848efb42021-09-09 12:30:53 +01001638 def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
1639 # Mess up input/output tensors for ERROR_IF checks
1640 if error_name == "WrongInputList":
1641 add_input = testGen.rng.choice([True, False])
1642 if add_input:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001643 input_list.append("eiDummyInput")
Matthew Haddon848efb42021-09-09 12:30:53 +01001644 else:
1645 input_list = input_list[:-1]
Les Bell0e027d42021-11-09 14:42:14 +00001646 elif error_name == "WrongOutputList":
Matthew Haddon848efb42021-09-09 12:30:53 +01001647 add_output = testGen.rng.choice([True, False])
1648 if add_output:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001649 output_list.append("eiDummyOutput")
Matthew Haddon848efb42021-09-09 12:30:53 +01001650 else:
1651 output_list = []
1652 return input_list, output_list
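    # Sketch of the effect (tensor names assumed): for a WrongInputList test
    # the name list either gains a dummy entry or loses its last entry, so the
    # serialized op no longer matches the operator's expected arity.
    #
    #   # in_list, out_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
    #   #     testGen, "WrongInputList", ["input-0", "input-1"], ["result-0"]
    #   # )
    #   # -> in_list becomes ["input-0", "input-1", "eiDummyInput"] or
    #   #    ["input-0"]; out_list is returned unchanged in this case.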
Matthew Haddone86fd342021-09-07 16:12:21 +01001653
Matthew Haddonc2025212021-10-08 21:21:05 +01001654 @staticmethod
Matthew Haddon630c17c2021-10-14 15:05:41 +01001655 def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001656 """Restrict the dimensions and overall size of a shape to
1657 max_dim and max_items.
1658 """
Matthew Haddon630c17c2021-10-14 15:05:41 +01001659 new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
1660 while product(new_shape) > max_items:
1661 new_shape = [max(d - 1, 1) for d in new_shape]
1662 return new_shape
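    # Worked example (shape assumed) of the clamping above:
    #
    #   # eiRestrictDimensions([3, 500, 80], max_dim=32, max_items=100000)
    #   # step 1: clamp each dimension to 32 -> [3, 32, 32] (3072 elements)
    #   # step 2: the product is already <= 100000, so that shape is returned.
    #   # Only if the product still exceeded max_items would every dimension be
    #   # repeatedly decremented (floored at 1) until the total fits.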
Matthew Haddone807aae2021-10-11 18:12:58 +01001663
1664 def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
1665 if error_name == ErrorIf.StartSmallerZero:
1666 newStart = []
1667 for i in range(len(input_shape)):
1668 newStart.append(testGen.rng.choice([-3, -2, -1]))
1669 return newStart, size
1670 elif error_name == ErrorIf.SizeSmallerEqualZero:
1671 newSize = []
1672 for i in range(len(input_shape)):
1673 newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
1674 return start, newSize
1675 elif error_name == ErrorIf.StartSizeOutsideBounds:
1676 newStart, newSize = [], []
1677 for i in range(len(input_shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001678 newStart.append(input_shape[i] - 1)
Matthew Haddone807aae2021-10-11 18:12:58 +01001679 newSize.append(testGen.rng.choice([2, 3, 4]))
1680 return newStart, newSize
1681 elif error_name == ErrorIf.InputSizeStartLengthMismatch:
1682 remove = testGen.rng.choice([True, False])
1683 if remove:
1684 newStart = start[1:]
1685 newSize = size[1:]
1686 else:
1687 newStart = start
1688 newStart.append(1)
1689 newSize = size
1690 newSize.append(1)
1691 return newStart, newSize
1692 else:
1693 return start, size
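    # Illustration (input shape assumed) of the StartSizeOutsideBounds branch:
    #
    #   # start, size = TosaErrorIfArgGen.eiSliceErrorIf(
    #   #     testGen, ErrorIf.StartSizeOutsideBounds, [1, 8, 8, 3], start, size
    #   # )
    #   # -> start == [0, 7, 7, 2] (the last valid index of each dimension) and
    #   #    each size is drawn from {2, 3, 4}, so start + size always overruns
    #   #    the input extent and triggers the out-of-bounds ERROR_IF.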
1694
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001695 @staticmethod
1696 def eiCastErrorIf(testGen, input_dtype):
1697 if input_dtype in [DType.BOOL, DType.FLOAT]:
1698 outputDType = [DType.BOOL, DType.INT48, DType.FLOAT]
1699 elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
1700 outputDType = [DType.INT48]
1701 else:
1702 assert False, f"input_dtype ({input_dtype}) not supported"
1703 return outputDType
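    # Examples of the invalid CAST targets selected above (not exhaustive):
    #
    #   # eiCastErrorIf(testGen, DType.INT8)  -> [DType.INT48]
    #   # eiCastErrorIf(testGen, DType.FLOAT) -> [DType.BOOL, DType.INT48,
    #   #                                         DType.FLOAT]
    #   # The caller picks one of these as the CAST output type, which is
    #   # guaranteed to hit the WrongOutputType ERROR_IF for that input type.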
1704
1705
Matthew Haddone86fd342021-09-07 16:12:21 +01001706class TosaErrorValidator:
Matthew Haddon848efb42021-09-09 12:30:53 +01001707 @staticmethod
1708 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
Les Bell729b0352021-11-24 10:28:21 +00001709 """Check ERROR_IF statements are caught and set the expected result.
1710
1711 Args:
1712 serializer: the serializer to set the expected result in
1713 validator_fcns: a sequence of validator functions to verify the result
1714 error_name: the name of the ERROR_IF condition to check for
1715 kwargs: keyword arguments for the validator functions
1716 Returns:
1717 True if the result matches the expected result; otherwise False
1718 """
1719 overall_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001720 for val_fcn in validator_fcns:
1721 val_result = val_fcn(True, **kwargs)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001722 validator_name = val_result["error_name"]
1723 error_result = val_result["error_result"]
1724 error_reason = val_result["error_reason"]
Matthew Haddon848efb42021-09-09 12:30:53 +01001725
Les Bell0e027d42021-11-09 14:42:14 +00001726 # expect an error IFF the error_name and validator_name match
1727 expected_result = error_result == (error_name == validator_name)
Les Bell729b0352021-11-24 10:28:21 +00001728 overall_result &= expected_result
Les Bell0e027d42021-11-09 14:42:14 +00001729
1730 if expected_result and error_result:
Jeremy Johnson2ec34942021-12-14 16:34:05 +00001731 serializer.setExpectedReturnCode(2, True, desc=error_reason)
Les Bell0e027d42021-11-09 14:42:14 +00001732 elif error_result: # and not expected_result
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001733 print(
1734 f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
1735 f" Expected: {error_name}, Got: {validator_name}"
1736 )
1737 elif not expected_result: # and not error_result
1738 print(
1739 f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
1740 f" Expected: {error_name}"
1741 )
Les Bell0e027d42021-11-09 14:42:14 +00001742
1743 if not expected_result:
1744 for k, v in sorted(kwargs.items()):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001745 if k != "op":
1746 if k.endswith("dtype"):
Les Bell0e027d42021-11-09 14:42:14 +00001747 v = valueToName(DType, v)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001748 print(f" {k} = {v}")
Matthew Haddon848efb42021-09-09 12:30:53 +01001749
Les Bell729b0352021-11-24 10:28:21 +00001750 return overall_result
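    # Sketch of the validator contract used above (field values are examples):
    # every evXxx function, called with check=True and the op's kwargs, returns
    # a dictionary of this shape, and evValidateErrorIfs compares error_result
    # against the error_name the test was generated for.
    #
    #   # {
    #   #     "error_name": ErrorIf.WrongRank,
    #   #     "error_result": True,   # the condition was detected in kwargs
    #   #     "error_reason": "Rank not supported for this operator",
    #   #     "param_reqs": {"rank": [5], "dtype": None, "shape": None},
    #   # }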
1751
Matthew Haddon848efb42021-09-09 12:30:53 +01001752 @staticmethod
1753 def evWrongInputType(check=False, **kwargs):
Les Bell0e027d42021-11-09 14:42:14 +00001754 error_result = False
Matthew Haddon848efb42021-09-09 12:30:53 +01001755
1756 # Find the unsupported input data types
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001757 op = kwargs["op"]
1758 input_dtypes = op["types"]
1759 allowed_input_dtypes = {
1760 t[0] if isinstance(t, list) else t for t in input_dtypes
1761 }
Les Bell0e027d42021-11-09 14:42:14 +00001762 wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
Matthew Haddon848efb42021-09-09 12:30:53 +01001763
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001764 if op["op"] == Op.CLAMP:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001765 wrong_input_dtypes.remove(DType.INT48)
1766
Matthew Haddon848efb42021-09-09 12:30:53 +01001767 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001768 input_dtype = kwargs["input_dtype"]
Les Bell0e027d42021-11-09 14:42:14 +00001769 if input_dtype not in allowed_input_dtypes:
Matthew Haddon848efb42021-09-09 12:30:53 +01001770 error_result = True
1771
1772 info_dict = {
Les Bell0e027d42021-11-09 14:42:14 +00001773 "error_name": ErrorIf.WrongInputType,
Matthew Haddon848efb42021-09-09 12:30:53 +01001774 "error_result": error_result,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001775 "error_reason": "Input data type not supported for this operator",
1776 "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
Matthew Haddon848efb42021-09-09 12:30:53 +01001777 }
1778 return info_dict
1779
1780 @staticmethod
1781 def evWrongOutputType(check=False, **kwargs):
Matthew Haddon848efb42021-09-09 12:30:53 +01001782 error_result = False
Matthew Haddon848efb42021-09-09 12:30:53 +01001783
1784 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001785 input_dtype = kwargs["input_dtype"]
1786 output_dtype = kwargs["output_dtype"]
1787 op = kwargs["op"]
Matthew Haddon848efb42021-09-09 12:30:53 +01001788
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001789 if op["op"] == Op.RESIZE:
1790 mode = kwargs["mode"]
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001791 if (
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001792 (
1793 mode == ResizeMode.NEAREST
1794 and input_dtype == DType.INT8
1795 and output_dtype != DType.INT8
1796 )
1797 or (
1798 mode == ResizeMode.NEAREST
1799 and input_dtype == DType.INT16
1800 and output_dtype != DType.INT16
1801 )
1802 or (
1803 mode == ResizeMode.BILINEAR
1804 and input_dtype == DType.INT8
1805 and output_dtype != DType.INT32
1806 )
1807 or (
1808 mode == ResizeMode.BILINEAR
1809 and input_dtype == DType.INT16
1810 and output_dtype != DType.INT48
1811 )
1812 or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001813 ):
1814 error_result = True
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001815
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001816 elif op["op"] == Op.RESCALE:
Matthew Haddonc2025212021-10-08 21:21:05 +01001817 if input_dtype == DType.INT8:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001818 if output_dtype not in [
1819 DType.UINT8,
1820 DType.INT8,
1821 DType.INT16,
1822 DType.INT32,
1823 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01001824 error_result = True
1825 if input_dtype in [DType.INT16, DType.INT32]:
1826 if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
1827 error_result = True
1828 elif input_dtype == DType.INT48:
1829 if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
1830 error_result = True
1831 elif input_dtype == DType.UINT8:
1832 if output_dtype != DType.INT8:
1833 error_result = True
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001834
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001835 elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001836 if (
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001837 (input_dtype == DType.INT8 and output_dtype != DType.INT32)
1838 or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
1839 or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001840 ):
1841 error_result = True
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001842
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001843 elif op["op"] == Op.ARGMAX:
1844 if (
1845 input_dtype in [DType.INT8, DType.INT16, DType.FLOAT]
1846 and output_dtype != DType.INT32
1847 ):
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001848 error_result = True
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001849
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001850 elif op["op"] == Op.MUL:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001851 if input_dtype != DType.FLOAT and output_dtype != DType.INT32:
1852 error_result = True
1853 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
1854 error_result = True
1855
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001856 elif op["op"] == Op.TABLE:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001857 if input_dtype == DType.INT8 and output_dtype != DType.INT8:
1858 error_result = True
1859 elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
1860 error_result = True
1861
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001862 elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001863 if output_dtype != DType.BOOL:
1864 error_result = True
1865
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001866 elif op["op"] == Op.CAST:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001867 if (
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001868 (
1869 input_dtype == DType.BOOL
1870 and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
1871 )
1872 or (
1873 input_dtype == DType.INT8
1874 and output_dtype
1875 not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
1876 )
1877 or (
1878 input_dtype == DType.INT16
1879 and output_dtype
1880 not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
1881 )
1882 or (
1883 input_dtype == DType.INT32
1884 and output_dtype
1885 not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
1886 )
1887 or (
1888 input_dtype == DType.FLOAT
1889 and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
1890 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001891 ):
1892 error_result = True
1893
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001894 elif op["op"] in {
1895 Op.CONV2D,
1896 Op.CONV3D,
1897 Op.DEPTHWISE_CONV2D,
1898 Op.TRANSPOSE_CONV2D,
1899 }:
Les Bell0e027d42021-11-09 14:42:14 +00001900 if (
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001901 input_dtype == DType.INT8
1902 and output_dtype != DType.INT32
1903 or input_dtype == DType.INT16
1904 and output_dtype != DType.INT48
1905 or input_dtype == DType.FLOAT
1906 and output_dtype != DType.FLOAT
Les Bell0e027d42021-11-09 14:42:14 +00001907 ):
1908 error_result = True
1909 # invalid input types are ignored, to avoid reporting multiple errors
1910
Matthew Haddoneacff9a2021-09-24 14:42:13 +01001911 else:
1912 if output_dtype != input_dtype:
1913 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001914
1915 info_dict = {
Les Bell0e027d42021-11-09 14:42:14 +00001916 "error_name": ErrorIf.WrongOutputType,
Matthew Haddon848efb42021-09-09 12:30:53 +01001917 "error_result": error_result,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001918 "error_reason": (
1919 "Output data type not supported for this configuration of operator"
1920 ),
1921 "param_reqs": {"rank": None, "dtype": None, "shape": None},
Matthew Haddon848efb42021-09-09 12:30:53 +01001922 }
1923 return info_dict
1924
1925 @staticmethod
1926 def evWrongRank(check=False, **kwargs):
1927 all_ranks = (1, 2, 3, 4, 5)
1928
1929 # Make a list of incorrect ranks
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001930 assert "op" in kwargs
1931 op = kwargs["op"]
1932 rmin, rmax = op["rank"]
Matthew Haddon848efb42021-09-09 12:30:53 +01001933 rank_range = range(rmin, rmax + 1)
1934 incorrect_ranks = list(set(all_ranks) - set(rank_range))
Matthew Haddonc2025212021-10-08 21:21:05 +01001935 # Remove small incorrect ranks to avoid index errors
1936 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
Matthew Haddon848efb42021-09-09 12:30:53 +01001937 # Set minimum incorrect rank to 3 to avoid index error
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001938 if op["op"] in [Op.RESIZE]:
Matthew Haddon848efb42021-09-09 12:30:53 +01001939 incorrect_ranks = [3, 5]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001940 elif op["op"] in [Op.TRANSPOSE]:
Matthew Haddon01c359d2021-10-15 16:30:48 +01001941 incorrect_ranks = [7, 8]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001942 elif op["op"] in [Op.CONV3D]:
Les Bell0e027d42021-11-09 14:42:14 +00001943 incorrect_ranks = [6, 7]
Matthew Haddon848efb42021-09-09 12:30:53 +01001944
1945 error_name = ErrorIf.WrongRank
1946 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
1947 error_result = False
1948 error_reason = "Rank not supported for this operator"
1949
1950 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001951 input_shape = kwargs["input_shape"]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001952
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001953 if (
1954 op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
1955 and len(input_shape) != 4
1956 ):
Matthew Haddon848efb42021-09-09 12:30:53 +01001957 error_result = True
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001958 elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001959 error_result = True
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001960 elif op["op"] == Op.MATMUL and len(input_shape) != 3:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001961 error_result = True
Matthew Haddonc2025212021-10-08 21:21:05 +01001962 else:
1963 if len(input_shape) not in rank_range:
1964 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001965
1966 info_dict = {
1967 "error_name": error_name,
1968 "error_result": error_result,
1969 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001970 "param_reqs": param_reqs,
Matthew Haddon848efb42021-09-09 12:30:53 +01001971 }
1972 return info_dict
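    # Worked example (op assumed) of the incorrect-rank selection above:
    #
    #   # op["rank"] == (1, 4) with all_ranks == (1, 2, 3, 4, 5)
    #   # -> incorrect_ranks = [5]; ranks at or below rmin are discarded to
    #   #    avoid indexing errors, and RESIZE/TRANSPOSE/CONV3D get hand-picked
    #   #    lists instead. A generated rank-5 input then makes
    #   #    len(input_shape) fall outside range(1, 5), so error_result = True.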
1973
1974 @staticmethod
1975 def evWrongInputList(check=False, **kwargs):
1976 error_name = ErrorIf.WrongInputList
1977 param_reqs = {"rank": None, "dtype": None, "shape": None}
1978 error_result = False
1979 error_reason = "Op input list does not match expected input"
1980
1981 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001982 op = kwargs["op"]
1983 input_list = kwargs["input_list"]
1984 num_operands = kwargs["num_operands"]
1985 if op["op"] in [Op.SCATTER, Op.GATHER]:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001986 # SCATTER/GATHER add an indices input tensor in their build functions
1987 num_operands += 1
Kevin Chengfe392ce2021-10-18 21:51:55 +00001988 if len(input_list) != num_operands:
1989 error_result = True
Matthew Haddon848efb42021-09-09 12:30:53 +01001990
1991 info_dict = {
1992 "error_name": error_name,
1993 "error_result": error_result,
1994 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001995 "param_reqs": param_reqs,
Matthew Haddon848efb42021-09-09 12:30:53 +01001996 }
1997 return info_dict
1998
1999 @staticmethod
2000 def evWrongOutputList(check=False, **kwargs):
2001 error_name = ErrorIf.WrongOutputList
2002 param_reqs = {"rank": None, "dtype": None, "shape": None}
2003 error_result = False
2004 error_reason = "Op output list does not match expected output"
2005
2006 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002007 output_list = kwargs["output_list"]
Matthew Haddon848efb42021-09-09 12:30:53 +01002008 # Note this will be incorrect if an operator returns more than one output
2009 if len(output_list) != 1:
2010 error_result = True
2011
2012 info_dict = {
2013 "error_name": error_name,
2014 "error_result": error_result,
2015 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002016 "param_reqs": param_reqs,
Matthew Haddon848efb42021-09-09 12:30:53 +01002017 }
2018 return info_dict
Matthew Haddone86fd342021-09-07 16:12:21 +01002019
2020 @staticmethod
2021 def evMaxDimExceeded(check=False, **kwargs):
2022 error_name = ErrorIf.MaxDimExceeded
Matthew Haddon848efb42021-09-09 12:30:53 +01002023 param_reqs = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002024 "rank": [4, 4],
Matthew Haddon848efb42021-09-09 12:30:53 +01002025 "dtype": [DType.INT8],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002026 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
2027 }
Matthew Haddone86fd342021-09-07 16:12:21 +01002028 error_result = False
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002029 error_reason = (
2030 "At least one maximum dimension is greater than or equal to 16384"
2031 )
Matthew Haddone86fd342021-09-07 16:12:21 +01002032
2033 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002034 input_shape = kwargs["input_shape"]
2035 output_shape = kwargs["output_shape"] # Note this is just (OH, OW)
2036 if (
2037 (input_shape[1] >= 16384)
2038 or (input_shape[2] >= 16384)
2039 or (output_shape[0] >= 16384)
2040 or (output_shape[1] >= 16384)
2041 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01002042 error_result = True
2043
2044 info_dict = {
2045 "error_name": error_name,
2046 "error_result": error_result,
2047 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002048 "param_reqs": param_reqs,
Matthew Haddone86fd342021-09-07 16:12:21 +01002049 }
2050 return info_dict
2051
2052 @staticmethod
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002053 def evBatchMismatch(check=False, **kwargs):
2054 error_name = ErrorIf.BatchMismatch
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002055 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002056 error_result = False
2057 error_reason = "Input batch size not equal to output batch size"
2058
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002059 assert "op" in kwargs
2060 op = kwargs["op"]
2061 rmin, rmax = op["rank"]
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002062 rank_range = range(rmin, rmax + 1)
2063
2064 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002065 input_shape = kwargs["input_shape"]
2066 output_shape = kwargs[
2067 "result_tensor"
2068 ].shape # Note this is just (N, OH, OW, C)
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002069
2070 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
2071 error_result = True
2072
2073 info_dict = {
2074 "error_name": error_name,
2075 "error_result": error_result,
2076 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002077 "param_reqs": param_reqs,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002078 }
2079 return info_dict
2080
2081 @staticmethod
2082 def evChannelMismatch(check=False, **kwargs):
2083 error_name = ErrorIf.ChannelMismatch
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002084 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002085 error_result = False
2086 error_reason = "Input channel size not equal to output channel size"
2087
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002088 assert "op" in kwargs
2089 op = kwargs["op"]
2090 rmin, rmax = op["rank"]
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002091 rank_range = range(rmin, rmax + 1)
2092
2093 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002094 input_shape = kwargs["input_shape"]
2095 output_shape = kwargs[
2096 "result_tensor"
2097 ].shape # Note this is just (N, OH, OW, C)
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002098 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
2099 error_result = True
2100
2101 info_dict = {
2102 "error_name": error_name,
2103 "error_result": error_result,
2104 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002105 "param_reqs": param_reqs,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01002106 }
2107 return info_dict
2108
2109 @staticmethod
Matthew Haddone86fd342021-09-07 16:12:21 +01002110 def evStrideSmallerEqualZero(check=False, **kwargs):
2111 error_name = ErrorIf.StrideSmallerEqualZero
2112 param_reqs = {"rank": None, "dtype": None, "shape": None}
2113 error_result = False
2114 error_reason = "Stride value smaller than or equal zero"
2115
2116 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002117 input_dtype = kwargs["input_dtype"]
2118 output_dtype = kwargs["output_dtype"]
Matthew Haddon848efb42021-09-09 12:30:53 +01002119 if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002120 stride = kwargs["stride"] # Work around wrong input/output type tests
Matthew Haddon848efb42021-09-09 12:30:53 +01002121 elif output_dtype == DType.FLOAT:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002122 stride = kwargs["stride_fp"]
Matthew Haddon848efb42021-09-09 12:30:53 +01002123 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002124 stride = kwargs[
2125 "stride_fp"
2126 ] # Work around wrong input/output type tests
Matthew Haddone86fd342021-09-07 16:12:21 +01002127 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002128 stride = kwargs["stride"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002129
2130 if min(stride) <= 0:
2131 error_result = True
2132
2133 info_dict = {
2134 "error_name": error_name,
2135 "error_result": error_result,
2136 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002137 "param_reqs": param_reqs,
Matthew Haddone86fd342021-09-07 16:12:21 +01002138 }
2139 return info_dict
2140
2141 @staticmethod
2142 def evStrideLargerEqualMax(check=False, **kwargs):
2143 error_name = ErrorIf.StrideLargerEqualMax
2144 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
2145 error_result = False
2146 error_reason = "Stride value larger than or equal to maximum value"
2147
2148 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002149 shift = kwargs["shift"]
2150 input_dtype = kwargs["input_dtype"]
2151 stride = kwargs["stride"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002152 if input_dtype in [DType.INT8, DType.INT16]:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002153 if shift >= 0 and (
2154 stride[0] >= (16 << shift) or stride[1] >= (16 << shift)
2155 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01002156 error_result = True
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002157 elif shift < 0 and (
2158 stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)
2159 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01002160 error_result = True
2161
2162 info_dict = {
2163 "error_name": error_name,
2164 "error_result": error_result,
2165 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002166 "param_reqs": param_reqs,
Matthew Haddone86fd342021-09-07 16:12:21 +01002167 }
2168 return info_dict
2169
Matthew Haddone86fd342021-09-07 16:12:21 +01002170 @staticmethod
2171 def evStrideLargerDimension(check=False, **kwargs):
2172 error_name = ErrorIf.StrideLargerDimension
2173 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
2174 error_result = False
2175 error_reason = "Stride value larger than or equal to H/W dimension"
2176
2177 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002178 shape = kwargs["input_shape"]
2179 input_dtype = kwargs["input_dtype"]
2180 stride = kwargs["stride_fp"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002181
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002182 if input_dtype == DType.FLOAT and (
2183 (stride[0] > shape[1]) or (stride[1] > shape[2])
2184 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01002187 error_result = True
2188
2189 info_dict = {
2190 "error_name": error_name,
2191 "error_result": error_result,
2192 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002193 "param_reqs": param_reqs,
Matthew Haddone86fd342021-09-07 16:12:21 +01002194 }
2195 return info_dict
2196
Matthew Haddone86fd342021-09-07 16:12:21 +01002197 @staticmethod
2198 def evOffsetSmallerEqualMin(check=False, **kwargs):
2199 error_name = ErrorIf.OffsetSmallerEqualMin
2200 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
2201 error_result = False
2202 error_reason = "Offset value smaller than or equal to minimum value"
2203
2204 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002205 shift = kwargs["shift"]
2206 output_dtype = kwargs["output_dtype"]
Matthew Haddon848efb42021-09-09 12:30:53 +01002207 if output_dtype == DType.FLOAT:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002208 offset = kwargs["offset_fp"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002209 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002210 offset = kwargs["offset"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002211
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002212 if shift >= 0 and (
2213 offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)
2214 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01002215 error_result = True
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002216 elif shift < 0 and (
2217 offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)
2218 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01002219 error_result = True
2220
2221 info_dict = {
2222 "error_name": error_name,
2223 "error_result": error_result,
2224 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002225 "param_reqs": param_reqs,
Matthew Haddone86fd342021-09-07 16:12:21 +01002226 }
2227 return info_dict
2228
2229 @staticmethod
2230 def evOffsetLargerEqualMax(check=False, **kwargs):
2231 error_name = ErrorIf.OffsetLargerEqualMax
2232 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
2233 error_result = False
2234 error_reason = "Offset value larger than or equal to maximum value"
2235
2236 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002237 shift = kwargs["shift"]
2238 output_dtype = kwargs["output_dtype"]
Matthew Haddon848efb42021-09-09 12:30:53 +01002239 if output_dtype == DType.FLOAT:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002240 offset = kwargs["offset_fp"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002241 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002242 offset = kwargs["offset"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002243
2244 if shift >= 0:
2245 if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
2246 error_result = True
2247
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002248 if shift >= 0 and (
2249 offset[0] >= (16 << shift) or offset[1] >= (16 << shift)
2250 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01002251 error_result = True
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002252 elif shift < 0 and (
2253 offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)
2254 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01002255 error_result = True
2256
2257 info_dict = {
2258 "error_name": error_name,
2259 "error_result": error_result,
2260 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002261 "param_reqs": param_reqs,
Matthew Haddone86fd342021-09-07 16:12:21 +01002262 }
2263 return info_dict
2264
2265 @staticmethod
2266 def evShiftNotZero(check=False, **kwargs):
2267 error_name = ErrorIf.ShiftNotZero
2268 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
2269 error_result = False
2270 error_reason = "Shift value must be zero for float input"
2271
2272 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002273 shift = kwargs["shift"]
2274 input_dtype = kwargs["input_dtype"]
2275 output_dtype = kwargs["output_dtype"]
2276 if (
2277 input_dtype == DType.FLOAT
2278 and output_dtype == DType.FLOAT
2279 and shift != 0
2280 ):
Matthew Haddone86fd342021-09-07 16:12:21 +01002281 error_result = True
2282
2283 info_dict = {
2284 "error_name": error_name,
2285 "error_result": error_result,
2286 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002287 "param_reqs": param_reqs,
Matthew Haddone86fd342021-09-07 16:12:21 +01002288 }
2289 return info_dict
2290
Matthew Haddone86fd342021-09-07 16:12:21 +01002291 @staticmethod
2292 def evShiftSmallerOne(check=False, **kwargs):
2293 error_name = ErrorIf.ShiftSmallerOne
2294 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
2295 error_result = False
2296 error_reason = "Shift value smaller than one"
2297
2298 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002299 shift = kwargs["shift"]
2300 input_dtype = kwargs["input_dtype"]
2301 output_dtype = kwargs["output_dtype"]
Matthew Haddon848efb42021-09-09 12:30:53 +01002302 if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01002303 error_result = True
2304
2305 info_dict = {
2306 "error_name": error_name,
2307 "error_result": error_result,
2308 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002309 "param_reqs": param_reqs,
Matthew Haddone86fd342021-09-07 16:12:21 +01002310 }
2311 return info_dict
2312
2313 @staticmethod
2314 def evShiftLargerEleven(check=False, **kwargs):
2315 error_name = ErrorIf.ShiftLargerEleven
2316 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
2317 error_result = False
2318 error_reason = "Shift value larger than eleven"
2319
2320 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002321 shift = kwargs["shift"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002322 if shift > 11:
2323 error_result = True
2324
2325 info_dict = {
2326 "error_name": error_name,
2327 "error_result": error_result,
2328 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002329 "param_reqs": param_reqs,
Matthew Haddone86fd342021-09-07 16:12:21 +01002330 }
2331 return info_dict
2332
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002333 @staticmethod
2334 def evRankMismatch(check=False, **kwargs):
2335 error_name = ErrorIf.RankMismatch
2336 param_reqs = {"rank": None, "dtype": None, "shape": None}
2337 error_result = False
2338 error_reason = "Input Rank does not match output rank"
2339
2340 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002341 input1_shape = kwargs["input1"].shape
2342 input2_shape = kwargs["input2"].shape
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002343 # In case of SELECT op
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002344 input3_shape = (
2345 kwargs["input3"].shape if "input3" in kwargs else input2_shape
2346 )
2347 output_shape = kwargs["result_tensor"].shape
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002348 if (
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002349 (len(input1_shape) != len(output_shape))
2350 or (len(input2_shape) != len(output_shape))
2351 or (len(input3_shape) != len(output_shape))
2352 ):
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002353 error_result = True
2354
2355 info_dict = {
2356 "error_name": error_name,
2357 "error_result": error_result,
2358 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002359 "param_reqs": param_reqs,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01002360 }
2361 return info_dict
2362
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002363 @staticmethod
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002364 def evDimensionMismatch(check=False, **kwargs):
2365 error_name = ErrorIf.DimensionMismatch
2366 param_reqs = {"rank": None, "dtype": None, "shape": None}
2367 error_result = False
2368 error_reason = "Input Dimensions do not match output"
2369
2370 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002371 input1_shape = kwargs["input1"].shape
2372 input2_shape = kwargs["input2"].shape
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002373 # In case of SELECT op
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002374 input3_shape = (
2375 kwargs["input3"].shape if "input3" in kwargs else input2_shape
2376 )
2377 output_shape = kwargs["result_tensor"].shape
2378 for i in range(
2379 min(len(input1_shape), len(input2_shape), len(input3_shape))
2380 ):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002381 if (
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002382 (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
2383 or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
2384 or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
2385 ):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002386 error_result = True
2387
2388 info_dict = {
2389 "error_name": error_name,
2390 "error_result": error_result,
2391 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002392 "param_reqs": param_reqs,
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00002393 }
2394 return info_dict
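    # Broadcast example (shapes assumed) for the dimension check above:
    #
    #   # input1 (2, 3, 4) + input2 (2, 1, 4) -> output (2, 3, 4): every input
    #   # dimension is either 1 or equal to the output dimension, so
    #   # error_result stays False. If input2 were (2, 2, 4), its middle
    #   # dimension is neither 1 nor 3, and DimensionMismatch is flagged.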
2395
2396 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002397 def evInputZeroPointNotZero(check=False, **kwargs):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002398 op = kwargs["op"]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002399 error_result = False
Les Bell0e027d42021-11-09 14:42:14 +00002400
2401 # Quantizable types
2402 qTypes = (DType.INT8, DType.UINT8)
2403
2404 # This does not apply to quantizable types
2405 inputDtypes = [
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002406 dtype
2407 for dtype in op["types"]
2408 if (isinstance(dtype, list) and dtype[0] not in qTypes)
2409 or (not isinstance(dtype, list) and dtype not in qTypes)
Les Bell0e027d42021-11-09 14:42:14 +00002410 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002411
2412 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002413 input_dtype = kwargs["input_dtype"]
2414 if isinstance(kwargs["qinfo"], tuple):
2415 qinfo = kwargs["qinfo"]
Matthew Haddonc2025212021-10-08 21:21:05 +01002416 input_zero_point = qinfo[0]
2417 else:
2418 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002419 qinfo = kwargs["qinfo"].ints
Matthew Haddonc2025212021-10-08 21:21:05 +01002420 input_zero_point = qinfo[0][1]
2421
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002422 if op["op"] == Op.MATMUL:
2423 qinfo = kwargs["qinfo"].ints
Les Bell0e027d42021-11-09 14:42:14 +00002424 for dtype, zp in (
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002425 (kwargs["input_dtype"], qinfo[0][1]),
2426 (kwargs["input2_dtype"], qinfo[1][1]),
Les Bell0e027d42021-11-09 14:42:14 +00002427 ):
2428 if dtype not in qTypes and zp != 0:
2429 error_result = True
2430 break
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002431 else:
Les Bell0e027d42021-11-09 14:42:14 +00002432 error_result = input_dtype not in qTypes and input_zero_point != 0
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002433
2434 info_dict = {
Les Bell0e027d42021-11-09 14:42:14 +00002435 "error_name": ErrorIf.InputZeroPointNotZero,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002436 "error_result": error_result,
Les Bell0e027d42021-11-09 14:42:14 +00002437 "error_reason": "Input DType not INT8 and zero point not 0",
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002438 "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002439 }
2440 return info_dict
2441
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002442 @staticmethod
2443 def evWeightZeroPointNotZero(check=False, **kwargs):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002444 op = kwargs["op"]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002445
2446 # exclude inputs with INT8 weights
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002447 inputDtypes = [
2448 t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
2449 ]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002450
2451 error_name = ErrorIf.WeightZeroPointNotZero
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002452 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002453 error_result = False
2454 error_reason = "Weight DType not INT8 and zero point not 0"
2455
2456 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002457 weight_dtype = kwargs["weight_dtype"]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002458 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002459 qinfo = kwargs["qinfo"].ints
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002460 weight_zero_point = qinfo[1][1]
2461 if weight_dtype != DType.INT8 and weight_zero_point != 0:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002462 error_result = True
2463
2464 info_dict = {
2465 "error_name": error_name,
2466 "error_result": error_result,
2467 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002468 "param_reqs": param_reqs,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002469 }
2470 return info_dict
2471
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002472 @staticmethod
2473 def evOutputZeroPointNotZero(check=False, **kwargs):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002474 op = kwargs["op"]
2475 inputDtypes = op["types"].copy()
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002476 if DType.INT8 in inputDtypes:
2477 inputDtypes.remove(DType.INT8)
2478 if DType.UINT8 in inputDtypes:
2479 inputDtypes.remove(DType.UINT8)
2480
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002481 error_name = ErrorIf.OutputZeroPointNotZero
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002482 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002483 error_result = False
2484 error_reason = "Output DType not INT8 and zero point not 0"
2485
2486 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002487 input_dtype = kwargs["input_dtype"]
2488 output_dtype = kwargs["output_dtype"]
2489 if isinstance(kwargs["qinfo"], tuple):
2490 qinfo = kwargs["qinfo"]
Matthew Haddonc2025212021-10-08 21:21:05 +01002491 output_zero_point = qinfo[1]
2492 else:
2493 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002494 qinfo = kwargs["qinfo"].ints
Matthew Haddonc2025212021-10-08 21:21:05 +01002495 output_zero_point = qinfo[1][1]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002496 if op["op"] == Op.AVG_POOL2D:
Matthew Haddonc2025212021-10-08 21:21:05 +01002497 if input_dtype != DType.INT8 and output_zero_point != 0:
2498 error_result = True
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002499 elif (
2500 output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0
2501 ):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002502 error_result = True
2503
2504 info_dict = {
2505 "error_name": error_name,
2506 "error_result": error_result,
2507 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002508 "param_reqs": param_reqs,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002509 }
2510 return info_dict
2511
Matthew Haddond6ce7252021-09-29 15:35:44 +01002512 @staticmethod
2513 def evAxisSmallerZero(check=False, **kwargs):
2514 error_name = ErrorIf.AxisSmallerZero
2515 param_reqs = {"rank": None, "dtype": None, "shape": None}
2516 error_result = False
2517 error_reason = "Axis smaller than zero"
2518
2519 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002520 axis = kwargs["axis"]
Matthew Haddond6ce7252021-09-29 15:35:44 +01002521 if axis < 0:
2522 error_result = True
2523
2524 info_dict = {
2525 "error_name": error_name,
2526 "error_result": error_result,
2527 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002528 "param_reqs": param_reqs,
Matthew Haddond6ce7252021-09-29 15:35:44 +01002529 }
2530 return info_dict
2531
Matthew Haddond6ce7252021-09-29 15:35:44 +01002532 @staticmethod
2533 def evAxisLargerRank(check=False, **kwargs):
2534 error_name = ErrorIf.AxisLargerRank
2535 param_reqs = {"rank": None, "dtype": None, "shape": None}
2536 error_result = False
2537 error_reason = "Axis larger than rank"
2538
2539 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002540 axis = kwargs["axis"]
2541 shape = kwargs["input_shape"]
Matthew Haddond6ce7252021-09-29 15:35:44 +01002542 if axis > len(shape):
2543 error_result = True
2544
2545 info_dict = {
2546 "error_name": error_name,
2547 "error_result": error_result,
2548 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002549 "param_reqs": param_reqs,
Matthew Haddond6ce7252021-09-29 15:35:44 +01002550 }
2551 return info_dict
2552
Matthew Haddond6ce7252021-09-29 15:35:44 +01002553 @staticmethod
2554 def evShapeOfAxisNotOne(check=False, **kwargs):
2555 error_name = ErrorIf.ShapeOfAxisNotOne
2556 param_reqs = {"rank": None, "dtype": None, "shape": None}
2557 error_result = False
2558 error_reason = "shape[axis] is not equal to 1"
2559
2560 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002561 axis = kwargs["axis"]
2562 shape = kwargs["output_shape"]
Matthew Haddond6ce7252021-09-29 15:35:44 +01002563 if (0 <= axis < len(shape)) and shape[axis] != 1:
2564 error_result = True
2565
2566 info_dict = {
2567 "error_name": error_name,
2568 "error_result": error_result,
2569 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002570 "param_reqs": param_reqs,
Matthew Haddond6ce7252021-09-29 15:35:44 +01002571 }
2572 return info_dict
2573
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002574 @staticmethod
2575 def evPadSmallerZero(check=False, **kwargs):
2576 error_name = ErrorIf.PadSmallerZero
2577 param_reqs = {"rank": None, "dtype": None, "shape": None}
2578 error_result = False
2579 error_reason = "At least one pad is smaller than zero"
2580
2581 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002582 op = kwargs["op"]
2583 pad = kwargs["pad"]
2584 if op["op"] == Op.PAD:
Matthew Haddone807aae2021-10-11 18:12:58 +01002585 for padding in pad:
2586 if min(padding) < 0:
2587 error_result = True
2588 else:
2589 if min(pad) < 0:
2590 error_result = True
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002591
2592 info_dict = {
2593 "error_name": error_name,
2594 "error_result": error_result,
2595 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002596 "param_reqs": param_reqs,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002597 }
2598 return info_dict
2599
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002600 @staticmethod
2601 def evPadLargerEqualKernel(check=False, **kwargs):
2602 error_name = ErrorIf.PadLargerEqualKernel
2603 param_reqs = {"rank": None, "dtype": None, "shape": None}
2604 error_result = False
2605 error_reason = "At least one pad is larger than kernel dimension"
2606
2607 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002608 pad = kwargs["pad"]
2609 kernel = kwargs["kernel"]
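            # pad is laid out as [top, bottom, left, right] and kernel as [y, x],
            # matching the unpacking used in evPoolingOutputShapeMismatch below.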
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002610 if min(pad) > 0 and min(kernel) > 1:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002611 if (
2612 pad[0] >= kernel[0]
2613 or pad[1] >= kernel[0]
2614 or pad[2] >= kernel[1]
2615 or pad[3] >= kernel[1]
2616 ):
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002617 error_result = True
2618
2619 info_dict = {
2620 "error_name": error_name,
2621 "error_result": error_result,
2622 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002623 "param_reqs": param_reqs,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002624 }
2625 return info_dict
2626
2627 @staticmethod
2628 def evPoolingOutputShapeMismatch(check=False, **kwargs):
2629 error_name = ErrorIf.PoolingOutputShapeMismatch
2630 param_reqs = {"rank": None, "dtype": None, "shape": None}
2631 error_result = False
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002632 error_reason = (
2633 "Mismatch between output shape provided and expected output shape"
2634 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002635
2636 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002637 pad = kwargs["pad"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002638 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
2639
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002640 kernel = kwargs["kernel"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002641 kernel_y, kernel_x = kernel[0], kernel[1]
2642
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002643 input_shape = kwargs["input_shape"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002644 IH, IW = input_shape[1], input_shape[2]
2645
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002646 output_shape = kwargs["output_shape"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002647 OH, OW = output_shape[1], output_shape[2]
2648
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002649 stride = kwargs["stride"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002650 stride_y, stride_x = stride[0], stride[1]
2651
2652 # calculate correct height, width dimensions
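            # (IH + pad_top + pad_bottom + stride_y - kernel_y) // stride_y is the
            # usual floor((in + pads - kernel) / stride) + 1 pooling output size,
            # folded into a single integer division.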
2653 if stride_x != 0 and stride_y != 0:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002654 y_correct = (
2655 IH + pad_top + pad_bottom + stride_y - kernel_y
2656 ) // stride_y
2657 x_correct = (
2658 IW + pad_left + pad_right + stride_x - kernel_x
2659 ) // stride_x
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002660
2661 # ensure parameters are valid
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002662 params_valid = (
2663 min(kernel) >= 1
2664 and min(stride) >= 1
2665 and min(pad) >= 0
2666 and not (
2667 pad[0] >= kernel[0]
2668 or pad[1] >= kernel[0]
2669 or pad[2] >= kernel[1]
2670 or pad[3] >= kernel[1]
2671 )
2672 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002673
2674 if params_valid and (OH != y_correct or OW != x_correct):
2675 error_result = True
2676
2677 info_dict = {
2678 "error_name": error_name,
2679 "error_result": error_result,
2680 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002681 "param_reqs": param_reqs,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002682 }
2683 return info_dict
2684
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002685 @staticmethod
2686 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
2687 error_name = ErrorIf.ArgmaxOutputShapeMismatch
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002688 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002689 error_result = False
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002690 error_reason = (
2691 "Mismatch between output shape provided and expected output shape"
2692 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002693
2694 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002695 output_shape = kwargs["output_shape"]
2696 input_shape = kwargs["input_shape"]
2697 axis = kwargs["axis"]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002698
2699 dimension_match = True
2700 axis_shift = 0
2701
2702 # Check that rank is correct before trying to check dimensions
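            # The expected ARGMAX output shape is the input shape with the axis
            # dimension removed; axis_shift skips that dimension when comparing.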
2703 if (len(input_shape) - 1) == len(output_shape):
2704 for i in range(len(input_shape)):
2705 if i == axis:
2706 axis_shift = 1
2707 continue
2708 if input_shape[i] != output_shape[i - axis_shift]:
2709 dimension_match = False
2710
2711 if not dimension_match:
2712 error_result = True
2713
2714 info_dict = {
2715 "error_name": error_name,
2716 "error_result": error_result,
2717 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002718 "param_reqs": param_reqs,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002719 }
2720 return info_dict
2721
2722 @staticmethod
2723 def evArgmaxOutputRankMismatch(check=False, **kwargs):
2724 error_name = ErrorIf.ArgmaxOutputRankMismatch
2725 param_reqs = {"rank": None, "dtype": None, "shape": None}
2726 error_result = False
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002727 error_reason = (
2728 "Mismatch between output shape provided and expected output shape"
2729 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002730
2731 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002732 output_shape = kwargs["output_shape"]
2733 input_shape = kwargs["input_shape"]
2734 axis = kwargs["axis"]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002735 valid_params = axis >= 0 and axis < len(input_shape)
2736
2737 if valid_params and (len(input_shape) - 1) != len(output_shape):
2738 error_result = True
2739
2740 info_dict = {
2741 "error_name": error_name,
2742 "error_result": error_result,
2743 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002744 "param_reqs": param_reqs,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002745 }
2746 return info_dict
2747
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002748 @staticmethod
2749 def evKernelSmallerOne(check=False, **kwargs):
2750 error_name = ErrorIf.KernelSmallerOne
2751 param_reqs = {"rank": None, "dtype": None, "shape": None}
2752 error_result = False
        error_reason = "At least one kernel dimension is smaller than one"
2754
2755 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002756 kernel = kwargs["kernel"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002757 if min(kernel) < 1:
2758 error_result = True
2759
2760 info_dict = {
2761 "error_name": error_name,
2762 "error_result": error_result,
2763 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002764 "param_reqs": param_reqs,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002765 }
2766 return info_dict
2767
2768 @staticmethod
2769 def evStrideSmallerOne(check=False, **kwargs):
2770 error_name = ErrorIf.StrideSmallerOne
2771 param_reqs = {"rank": None, "dtype": None, "shape": None}
2772 error_result = False
        error_reason = "At least one stride dimension is smaller than one"
2774
2775 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002776 stride = kwargs["stride"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002777 if min(stride) < 1:
2778 error_result = True
2779
2780 info_dict = {
2781 "error_name": error_name,
2782 "error_result": error_result,
2783 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002784 "param_reqs": param_reqs,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002785 }
2786 return info_dict
2787
Matthew Haddonc2025212021-10-08 21:21:05 +01002788 @staticmethod
Les Bell0e027d42021-11-09 14:42:14 +00002789 def evDilationSmallerOne(check=False, **kwargs):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002790 error_result = check and min(kwargs["dilation"]) < 1
Les Bell0e027d42021-11-09 14:42:14 +00002791 return {
2792 "error_name": ErrorIf.DilationSmallerOne,
2793 "error_reason": "At least one dilation is smaller than one",
2794 "param_reqs": {"rank": None, "dtype": None, "shape": None},
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002795 "error_result": error_result,
Les Bell0e027d42021-11-09 14:42:14 +00002796 }
2797
2798 @staticmethod
Matthew Haddonc2025212021-10-08 21:21:05 +01002799 def evScaleTrue(check=False, **kwargs):
2800 error_name = ErrorIf.ScaleTrue
2801 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
2802 error_result = False
2803 error_reason = "Scale set to true but input type is INT48"
2804
2805 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002806 input_dtype = kwargs["input_dtype"]
2807 scale32 = kwargs["scale32"]
Matthew Haddonc2025212021-10-08 21:21:05 +01002808 if scale32 and input_dtype == DType.INT48:
2809 error_result = True
2810
2811 info_dict = {
2812 "error_name": error_name,
2813 "error_result": error_result,
2814 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002815 "param_reqs": param_reqs,
Matthew Haddonc2025212021-10-08 21:21:05 +01002816 }
2817 return info_dict
2818
2819 @staticmethod
2820 def evScaleNotTrue(check=False, **kwargs):
2821 error_name = ErrorIf.ScaleNotTrue
2822 param_reqs = {"rank": None, "dtype": None, "shape": None}
2823 error_result = False
2824 error_reason = "Scale set to false but double round set to true"
2825
2826 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002827 scale32 = kwargs["scale32"]
2828 double_round = kwargs["double_round"]
Matthew Haddonc2025212021-10-08 21:21:05 +01002829 if not scale32 and double_round:
2830 error_result = True
2831
2832 info_dict = {
2833 "error_name": error_name,
2834 "error_result": error_result,
2835 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002836 "param_reqs": param_reqs,
Matthew Haddonc2025212021-10-08 21:21:05 +01002837 }
2838 return info_dict
2839
Matthew Haddone807aae2021-10-11 18:12:58 +01002840 @staticmethod
2841 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
2842 error_name = ErrorIf.TensorSizeInputOutputMismatch
2843 param_reqs = {"rank": None, "dtype": None, "shape": None}
2844 error_result = False
2845 error_reason = "Input tensor size does not match output tensor size"
2846
2847 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002848 input_shape = kwargs["input_shape"]
2849 output_shape = kwargs["output_shape"]
Matthew Haddone807aae2021-10-11 18:12:58 +01002850 input_size = np.prod(input_shape)
2851 output_size = np.prod(output_shape)
2852 if input_size != output_size:
2853 error_result = True
2854
2855 info_dict = {
2856 "error_name": error_name,
2857 "error_result": error_result,
2858 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002859 "param_reqs": param_reqs,
Matthew Haddone807aae2021-10-11 18:12:58 +01002860 }
2861 return info_dict
2862
2863 @staticmethod
2864 def evStartSmallerZero(check=False, **kwargs):
2865 error_name = ErrorIf.StartSmallerZero
2866 param_reqs = {"rank": None, "dtype": None, "shape": None}
2867 error_result = False
2868 error_reason = "Starting point smaller than zero"
2869
2870 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002871 input_shape = kwargs["input_shape"]
2872 start = kwargs["start"]
Matthew Haddone807aae2021-10-11 18:12:58 +01002873 rank = len(input_shape)
2874 if len(start) == rank:
2875 for index in range(rank):
2876 if start[index] < 0:
2877 error_result = True
2878
2879 info_dict = {
2880 "error_name": error_name,
2881 "error_result": error_result,
2882 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002883 "param_reqs": param_reqs,
Matthew Haddone807aae2021-10-11 18:12:58 +01002884 }
2885 return info_dict
2886
Matthew Haddone807aae2021-10-11 18:12:58 +01002887 @staticmethod
2888 def evSizeSmallerEqualZero(check=False, **kwargs):
2889 error_name = ErrorIf.SizeSmallerEqualZero
2890 param_reqs = {"rank": None, "dtype": None, "shape": None}
2891 error_result = False
2892 error_reason = "Size smaller than or equal to zero"
2893
2894 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002895 input_shape = kwargs["input_shape"]
2896 size = kwargs["size"]
Matthew Haddone807aae2021-10-11 18:12:58 +01002897 rank = len(input_shape)
2898 if len(size) == rank:
2899 for index in range(rank):
2900 if size[index] <= 0:
2901 error_result = True
2902
2903 info_dict = {
2904 "error_name": error_name,
2905 "error_result": error_result,
2906 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002907 "param_reqs": param_reqs,
Matthew Haddone807aae2021-10-11 18:12:58 +01002908 }
2909 return info_dict
2910
Matthew Haddone807aae2021-10-11 18:12:58 +01002911 @staticmethod
2912 def evStartSizeOutsideBounds(check=False, **kwargs):
2913 error_name = ErrorIf.StartSizeOutsideBounds
2914 param_reqs = {"rank": None, "dtype": None, "shape": None}
2915 error_result = False
        error_reason = "Starting point plus size larger than input dimension"
2917
2918 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002919 input_shape = kwargs["input_shape"]
2920 start = kwargs["start"]
2921 size = kwargs["size"]
Matthew Haddone807aae2021-10-11 18:12:58 +01002922 rank = len(input_shape)
2923 if len(start) == rank and len(size) == rank:
2924 for index in range(rank):
2925 if start[index] + size[index] > input_shape[index]:
2926 error_result = True
2927
2928 info_dict = {
2929 "error_name": error_name,
2930 "error_result": error_result,
2931 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002932 "param_reqs": param_reqs,
Matthew Haddone807aae2021-10-11 18:12:58 +01002933 }
2934 return info_dict
2935
Matthew Haddone807aae2021-10-11 18:12:58 +01002936 @staticmethod
2937 def evSizeOutputShapeMismatch(check=False, **kwargs):
2938 error_name = ErrorIf.SizeOutputShapeMismatch
2939 param_reqs = {"rank": None, "dtype": None, "shape": None}
2940 error_result = False
2941 error_reason = "Size does not match output dimension"
2942
2943 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002944 input_shape = kwargs["input_shape"]
2945 output_shape = kwargs["output_shape"]
2946 size = kwargs["size"]
Matthew Haddone807aae2021-10-11 18:12:58 +01002947 rank = len(input_shape)
2948 if len(size) == rank:
2949 for index in range(rank):
2950 if size[index] != output_shape[index]:
2951 error_result = True
2952
2953 info_dict = {
2954 "error_name": error_name,
2955 "error_result": error_result,
2956 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002957 "param_reqs": param_reqs,
Matthew Haddone807aae2021-10-11 18:12:58 +01002958 }
2959 return info_dict
2960
2961 @staticmethod
2962 def evInputSizeStartLengthMismatch(check=False, **kwargs):
2963 error_name = ErrorIf.InputSizeStartLengthMismatch
2964 param_reqs = {"rank": None, "dtype": None, "shape": None}
2965 error_result = False
        error_reason = "Rank of input not equal to length of start or size"
2967
2968 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002969 input_shape = kwargs["input_shape"]
2970 start = kwargs["start"]
2971 size = kwargs["size"]
Matthew Haddone807aae2021-10-11 18:12:58 +01002972 rank = len(input_shape)
2973 if rank != len(start) or rank != len(size):
2974 error_result = True
2975
2976 info_dict = {
2977 "error_name": error_name,
2978 "error_result": error_result,
2979 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002980 "param_reqs": param_reqs,
Matthew Haddone807aae2021-10-11 18:12:58 +01002981 }
2982 return info_dict
2983
2984 @staticmethod
2985 def evIndexOutsideBounds(check=False, **kwargs):
2986 error_name = ErrorIf.IndexOutsideBounds
2987 param_reqs = {"rank": None, "dtype": None, "shape": None}
2988 error_result = False
2989 error_reason = "Index outside of allowed bounds"
2990
2991 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002992 input_shape = kwargs["input_shape"]
2993 perms = kwargs["perms"]
Matthew Haddone807aae2021-10-11 18:12:58 +01002994 rank = len(input_shape)
2995
2996 for index in perms:
2997 if index < 0 or index > rank:
2998 error_result = True
2999
3000 info_dict = {
3001 "error_name": error_name,
3002 "error_result": error_result,
3003 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003004 "param_reqs": param_reqs,
Matthew Haddone807aae2021-10-11 18:12:58 +01003005 }
3006 return info_dict
3007
3008 @staticmethod
3009 def evIndexUsedTwice(check=False, **kwargs):
3010 error_name = ErrorIf.IndexUsedTwice
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003011 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
Matthew Haddone807aae2021-10-11 18:12:58 +01003012 error_result = False
3013 error_reason = "Index used multiple times"
3014
3015 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003016 perms = kwargs["perms"]
Matthew Haddone807aae2021-10-11 18:12:58 +01003017
3018 unique_indices = []
3019 for index in perms:
3020 if index in unique_indices:
3021 error_result = True
3022 else:
3023 unique_indices.append(index)
3024
3025 info_dict = {
3026 "error_name": error_name,
3027 "error_result": error_result,
3028 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003029 "param_reqs": param_reqs,
Matthew Haddone807aae2021-10-11 18:12:58 +01003030 }
3031 return info_dict
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003032
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003033 @staticmethod
3034 def evMaxSmallerMin(check=False, **kwargs):
3035 error_name = ErrorIf.MaxSmallerMin
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003036 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003037 error_result = False
3038 error_reason = "Max value smaller than min value"
3039
3040 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003041 max_val = kwargs["max_val"]
3042 min_val = kwargs["min_val"]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003043 if max_val < min_val:
3044 error_result = True
3045
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003046 info_dict = {
3047 "error_name": error_name,
3048 "error_result": error_result,
3049 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003050 "param_reqs": param_reqs,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003051 }
3052 return info_dict
3053
3054 @staticmethod
3055 def evConcatInputRankMismatch(check=False, **kwargs):
3056 error_name = ErrorIf.ConcatInputRankMismatch
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003057 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003058 error_result = False
3059 error_reason = "Input ranks are not identical"
3060
3061 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003062 inputs = kwargs["inputs"]
3063 input_shape = kwargs["input_shape"]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003064 for input in inputs:
3065 if len(input.shape) != len(input_shape):
3066 error_result = True
3067
3068 info_dict = {
3069 "error_name": error_name,
3070 "error_result": error_result,
3071 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003072 "param_reqs": param_reqs,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003073 }
3074 return info_dict
3075
3076 @staticmethod
3077 def evConcatInputDimMismatch(check=False, **kwargs):
3078 error_name = ErrorIf.ConcatInputDimMismatch
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003079 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003080 error_result = False
        error_reason = "Input dimensions differ on an axis other than the concatenation axis"
3082
3083 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003084 inputs = kwargs["inputs"]
3085 input_shape = kwargs["input_shape"]
3086 axis = kwargs["axis"]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003087
3088 # Ensure rank is valid before checking dims.
3089 valid_rank = True
3090 for input in inputs:
3091 if len(input.shape) != len(input_shape):
3092 valid_rank = False
3093
3094 if valid_rank:
3095 for input in inputs:
3096 for i, dim in enumerate(input.shape):
3097 if dim != input_shape[i] and axis != i:
3098 error_result = True
3099
3100 info_dict = {
3101 "error_name": error_name,
3102 "error_result": error_result,
3103 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003104 "param_reqs": param_reqs,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003105 }
3106 return info_dict
3107
Matthew Haddon630c17c2021-10-14 15:05:41 +01003108 @staticmethod
Matthew Haddon01c359d2021-10-15 16:30:48 +01003109 def evConcatShapeSumMismatch(check=False, **kwargs):
3110 error_name = ErrorIf.ConcatShapeSumMismatch
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003111 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
Matthew Haddon01c359d2021-10-15 16:30:48 +01003112 error_result = False
3113 error_reason = "Sum of dimensions on axis not equal to output dimension"
3114
3115 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003116 inputs = kwargs["inputs"]
3117 input_shape = kwargs["input_shape"]
3118 output_shape = kwargs["output_shape"]
3119 axis = kwargs["axis"]
Matthew Haddon01c359d2021-10-15 16:30:48 +01003120
3121 # Ensure rank is valid before checking dims.
3122 valid_params = True
3123 for input in inputs:
3124 if len(input.shape) != len(input_shape):
3125 valid_params = False
3126 if axis < 0 or axis > len(input_shape):
3127 valid_params = False
3128
3129 if valid_params:
3130 axis_dim_sum = 0
3131 for input in inputs:
3132 axis_dim_sum += input.shape[axis]
3133
3134 if axis_dim_sum != output_shape[axis]:
3135 error_result = True
3136
Matthew Haddon01c359d2021-10-15 16:30:48 +01003137 info_dict = {
3138 "error_name": error_name,
3139 "error_result": error_result,
3140 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003141 "param_reqs": param_reqs,
Matthew Haddon01c359d2021-10-15 16:30:48 +01003142 }
3143 return info_dict
3144
3145 @staticmethod
Matthew Haddon630c17c2021-10-14 15:05:41 +01003146 def evInputListThenGraphMismatch(check=False, **kwargs):
3147 error_name = ErrorIf.CondIfInputListThenGraphMismatch
3148 param_reqs = {"rank": None, "dtype": None, "shape": None}
3149 error_result = False
3150 error_reason = "Input list shape does not match then-graph shape"
3151
3152 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003153 a = kwargs["a"]
3154 b = kwargs["b"]
3155 basicBlocks = kwargs["basicBlocks"]
Matthew Haddon630c17c2021-10-14 15:05:41 +01003156 then_block = basicBlocks[1]
3157 then_inputs = then_block.inputs
3158 then_tens = then_block.tensors
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003159 if (a.shape != then_tens[then_inputs[0]].shape) or (
3160 b.shape != then_tens[then_inputs[1]].shape
3161 ):
Matthew Haddon630c17c2021-10-14 15:05:41 +01003162 error_result = True
3163
3164 info_dict = {
3165 "error_name": error_name,
3166 "error_result": error_result,
3167 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003168 "param_reqs": param_reqs,
Matthew Haddon630c17c2021-10-14 15:05:41 +01003169 }
3170 return info_dict
3171
Matthew Haddon630c17c2021-10-14 15:05:41 +01003172 @staticmethod
3173 def evInputListElseGraphMismatch(check=False, **kwargs):
3174 error_name = ErrorIf.CondIfInputListElseGraphMismatch
3175 param_reqs = {"rank": None, "dtype": None, "shape": None}
3176 error_result = False
3177 error_reason = "Input list shape does not match else-graph shape"
3178
3179 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003180 a = kwargs["a"]
3181 b = kwargs["b"]
3182 basicBlocks = kwargs["basicBlocks"]
Matthew Haddon630c17c2021-10-14 15:05:41 +01003183 else_block = basicBlocks[2]
3184 else_inputs = else_block.inputs
3185 else_tens = else_block.tensors
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003186 if (a.shape != else_tens[else_inputs[0]].shape) or (
3187 b.shape != else_tens[else_inputs[1]].shape
3188 ):
Matthew Haddon630c17c2021-10-14 15:05:41 +01003189 error_result = True
3190
3191 info_dict = {
3192 "error_name": error_name,
3193 "error_result": error_result,
3194 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003195 "param_reqs": param_reqs,
Matthew Haddon630c17c2021-10-14 15:05:41 +01003196 }
3197 return info_dict
3198
Matthew Haddon630c17c2021-10-14 15:05:41 +01003199 @staticmethod
3200 def evOutputListThenGraphMismatch(check=False, **kwargs):
3201 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
3202 param_reqs = {"rank": None, "dtype": None, "shape": None}
3203 error_result = False
3204 error_reason = "Output list shape does not match then-graph shape"
3205
3206 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003207 basicBlocks = kwargs["basicBlocks"]
Matthew Haddon630c17c2021-10-14 15:05:41 +01003208 cond_block = basicBlocks[0]
3209 cond_outputs = cond_block.outputs
3210 cond_tens = cond_block.tensors
3211 then_block = basicBlocks[1]
3212 then_outputs = then_block.outputs
3213 then_tens = then_block.tensors
3214 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
3215 error_result = True
3216
3217 info_dict = {
3218 "error_name": error_name,
3219 "error_result": error_result,
3220 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003221 "param_reqs": param_reqs,
Matthew Haddon630c17c2021-10-14 15:05:41 +01003222 }
3223 return info_dict
3224
Matthew Haddon630c17c2021-10-14 15:05:41 +01003225 @staticmethod
3226 def evOutputListElseGraphMismatch(check=False, **kwargs):
3227 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
3228 param_reqs = {"rank": None, "dtype": None, "shape": None}
3229 error_result = False
3230 error_reason = "Output list shape does not match else-graph shape"
3231
3232 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003233 basicBlocks = kwargs["basicBlocks"]
Matthew Haddon630c17c2021-10-14 15:05:41 +01003234 cond_block = basicBlocks[0]
3235 cond_outputs = cond_block.outputs
3236 cond_tens = cond_block.tensors
3237 else_block = basicBlocks[2]
3238 else_outputs = else_block.outputs
3239 else_tens = else_block.tensors
3240 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
3241 error_result = True
3242
3243 info_dict = {
3244 "error_name": error_name,
3245 "error_result": error_result,
3246 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003247 "param_reqs": param_reqs,
Matthew Haddon630c17c2021-10-14 15:05:41 +01003248 }
3249 return info_dict
3250
Matthew Haddon630c17c2021-10-14 15:05:41 +01003251 @staticmethod
3252 def evInputListOutputListMismatch(check=False, **kwargs):
3253 error_name = ErrorIf.InputListOutputListMismatch
3254 param_reqs = {"rank": None, "dtype": None, "shape": None}
3255 error_result = False
3256 error_reason = "Input list does not match output list"
3257
3258 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003259 basicBlocks = kwargs["basicBlocks"]
Matthew Haddon630c17c2021-10-14 15:05:41 +01003260 while_block = basicBlocks[0]
3261 while_inputs = while_block.inputs
3262 while_outputs = while_block.outputs
3263 while_tens = while_block.tensors
3264 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
3265 error_result = True
3266
3267 info_dict = {
3268 "error_name": error_name,
3269 "error_result": error_result,
3270 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003271 "param_reqs": param_reqs,
Matthew Haddon630c17c2021-10-14 15:05:41 +01003272 }
3273 return info_dict
3274
Matthew Haddon630c17c2021-10-14 15:05:41 +01003275 @staticmethod
3276 def evInputListCondGraphMismatch(check=False, **kwargs):
3277 error_name = ErrorIf.InputListCondGraphMismatch
3278 param_reqs = {"rank": None, "dtype": None, "shape": None}
3279 error_result = False
3280 error_reason = "Input list does not match cond graph"
3281
3282 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003283 basicBlocks = kwargs["basicBlocks"]
Matthew Haddon630c17c2021-10-14 15:05:41 +01003284 while_block = basicBlocks[0]
3285 while_inputs = while_block.inputs
3286 while_tens = while_block.tensors
3287 cond_block = basicBlocks[1]
3288 cond_inputs = cond_block.inputs
3289 cond_tens = cond_block.tensors
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003290 if (
3291 while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
3292 ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
Matthew Haddon630c17c2021-10-14 15:05:41 +01003293 error_result = True
3294
3295 info_dict = {
3296 "error_name": error_name,
3297 "error_result": error_result,
3298 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003299 "param_reqs": param_reqs,
Matthew Haddon630c17c2021-10-14 15:05:41 +01003300 }
3301 return info_dict
3302
Matthew Haddon630c17c2021-10-14 15:05:41 +01003303 @staticmethod
3304 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
3305 error_name = ErrorIf.InputListBodyGraphInputMismatch
3306 param_reqs = {"rank": None, "dtype": None, "shape": None}
3307 error_result = False
3308 error_reason = "Input list does not match body graph input"
3309
3310 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003311 basicBlocks = kwargs["basicBlocks"]
Matthew Haddon630c17c2021-10-14 15:05:41 +01003312 while_block = basicBlocks[0]
3313 while_inputs = while_block.inputs
3314 while_tens = while_block.tensors
3315 body_block = basicBlocks[2]
            body_inputs = body_block.inputs
            body_tens = body_block.tensors
            if (
                while_tens[while_inputs[0]].shape != body_tens[body_inputs[0]].shape
            ) or (
                while_tens[while_inputs[1]].shape != body_tens[body_inputs[2]].shape
3322 ):
Matthew Haddon630c17c2021-10-14 15:05:41 +01003323 error_result = True
3324
3325 info_dict = {
3326 "error_name": error_name,
3327 "error_result": error_result,
3328 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003329 "param_reqs": param_reqs,
Matthew Haddon630c17c2021-10-14 15:05:41 +01003330 }
3331 return info_dict
3332
Matthew Haddon630c17c2021-10-14 15:05:41 +01003333 @staticmethod
3334 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
3335 error_name = ErrorIf.InputListBodyGraphOutputMismatch
3336 param_reqs = {"rank": None, "dtype": None, "shape": None}
3337 error_result = False
3338 error_reason = "Input list does not match body graph output"
3339
3340 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003341 basicBlocks = kwargs["basicBlocks"]
Matthew Haddon630c17c2021-10-14 15:05:41 +01003342 while_block = basicBlocks[0]
3343 while_inputs = while_block.inputs
3344 while_tens = while_block.tensors
3345 body_block = basicBlocks[2]
3346 body_outputs = body_block.outputs
3347 body_tens = body_block.tensors
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003348 if (
3349 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
3350 ) or (
3351 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
3352 ):
Matthew Haddon630c17c2021-10-14 15:05:41 +01003353 error_result = True
3354 info_dict = {
3355 "error_name": error_name,
3356 "error_result": error_result,
3357 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003358 "param_reqs": param_reqs,
Matthew Haddon630c17c2021-10-14 15:05:41 +01003359 }
3360 return info_dict
3361
Matthew Haddon630c17c2021-10-14 15:05:41 +01003362 @staticmethod
3363 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
3364 error_name = ErrorIf.CondGraphOutputNotMatchingBool
3365 param_reqs = {"rank": None, "dtype": None, "shape": None}
3366 error_result = False
        error_reason = "Cond graph output is not a matching list of booleans"
3368
3369 if check:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003370 basicBlocks = kwargs["basicBlocks"]
Matthew Haddon630c17c2021-10-14 15:05:41 +01003371 cond_block = basicBlocks[1]
3372 cond_outputs = cond_block.outputs
3373 cond_tens = cond_block.tensors
3374 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
3375 error_result = True
3376
3377 info_dict = {
3378 "error_name": error_name,
3379 "error_result": error_result,
3380 "error_reason": error_reason,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003381 "param_reqs": param_reqs,
Matthew Haddon630c17c2021-10-14 15:05:41 +01003382 }
3383 return info_dict
3384
Matthew Haddonb6b59e32021-10-07 17:19:20 +01003385
Matthew Haddonb724efc2021-08-25 16:40:29 +01003386class TosaInvalidValidator:
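    """Checks for test parameter combinations that would not form a valid test.

    Each iv* static method returns True when the generated arguments for an
    operator are invalid (for example, a bad resize mode or a non-positive
    output dimension) and the case should therefore be rejected.
    """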
Matthew Haddonb724efc2021-08-25 16:40:29 +01003387 @staticmethod
3388 def ivWrongDataTypeOrModeResize(**kwargs):
3389 input_dtype = kwargs["input_dtype"]
3390 args = kwargs["args"]
3391 mode = args[0]
Matthew Haddonb724efc2021-08-25 16:40:29 +01003392 output_dtype = args[8]
3393
3394 if mode == ResizeMode.BILINEAR:
            # Invalid unless the (input, output) dtype pair is one of the
            # valid BILINEAR combinations
            return not (
                (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                or (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                or (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
            )
3402 elif mode == ResizeMode.NEAREST:
3403 # Invalid output data type / Invalid input datatype
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003404 return (input_dtype != output_dtype) or (
3405 input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT]
Matthew Haddonb724efc2021-08-25 16:40:29 +01003406 )
3407 else:
3408 # Invalid resize mode
3409 return True
3410
3411 @staticmethod
3412 def ivBadStride(**kwargs):
3413 input_dtype = kwargs["input_dtype"]
3414 args = kwargs["args"]
3415 stride_x = args[1][0]
3416 stride_y = args[1][1]
3417 stride_fp_x = args[4][0]
3418 stride_fp_y = args[4][1]
3419
3420 if input_dtype == DType.FLOAT:
3421 if stride_fp_x <= 0 or stride_fp_y <= 0:
3422 # Negative or zero stride
3423 return True
3424 else:
3425 if stride_x <= 0 or stride_y <= 0:
3426 # Negative or zero stride
3427 return True
3428 return False
3429
Matthew Haddonb724efc2021-08-25 16:40:29 +01003430 @staticmethod
Les Bell0e027d42021-11-09 14:42:14 +00003431 def ivHeightWidthInvalid(**kwargs):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003432 opName = kwargs["opName"]
Matthew Haddonb724efc2021-08-25 16:40:29 +01003433
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003434 inputShapes = kwargs["shapeList"]
Les Bell0e027d42021-11-09 14:42:14 +00003435 input_shape = inputShapes[0]
Matthew Haddonb724efc2021-08-25 16:40:29 +01003436
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003437 args = kwargs["args"]
Matthew Haddonb724efc2021-08-25 16:40:29 +01003438 strides = args[0]
3439 padding = args[1]
Les Bell0e027d42021-11-09 14:42:14 +00003440
Matthew Haddonb724efc2021-08-25 16:40:29 +01003441 if opName.endswith("pool2d"):
Les Bell0e027d42021-11-09 14:42:14 +00003442 # avg_pool2d, max_pool2d
3443 kernel_shape = args[2]
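            # Same folded pooling output-size formula as the ERROR_IF checks:
            # floor((in + pads - kernel) / stride) + 1 per spatial dimension.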
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003444 h = (
3445 input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
3446 ) // strides[0]
3447 w = (
3448 input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
3449 ) // strides[1]
Les Bell0e027d42021-11-09 14:42:14 +00003450 # return True if any dimension is < 1
3451 return h < 1 or w < 1
Matthew Haddonb724efc2021-08-25 16:40:29 +01003452
Les Bell0e027d42021-11-09 14:42:14 +00003453 if opName.startswith("transpose_conv2d"):
3454 # transpose_conv2d
3455 dilations = args[2]
3456 output_shape = args[3]
3457 filter_shape = inputShapes[1]
3458 kernel_shape = filter_shape[1:-1]
Matthew Haddonb724efc2021-08-25 16:40:29 +01003459
Les Bell0e027d42021-11-09 14:42:14 +00003460 def get_out_size(in_size, stride, kernel_size, dilation, out_pad, in_pad):
3461 """Calculate the transpose_conv2d output size for a dimension.
Matthew Haddonb724efc2021-08-25 16:40:29 +01003462
Les Bell0e027d42021-11-09 14:42:14 +00003463 Based on the keras function deconv_output_length, in
3464 https://github.com/keras-team/keras/blob/master/keras/utils/conv_utils.py
Matthew Haddonb724efc2021-08-25 16:40:29 +01003465
Les Bell0e027d42021-11-09 14:42:14 +00003466 Args:
3467 in_size: the input size - int
3468 stride: the stride - int
3469 kernel_size: the kernel size - int
3470 dilation: the kernel dilation - int
3471 out_pad: the output padding - int
3472 in_pad: the input padding - int
3473
3474 Returns:
3475 the output size
3476 """
3477 dilated_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003478 return (
3479 (in_size - 1) * stride + dilated_kernel_size - 2 * in_pad + out_pad
3480 )
Les Bell0e027d42021-11-09 14:42:14 +00003481
3482 for pad_h, pad_w in (
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003483 (kernel_shape[0] - 1, kernel_shape[1] - 1), # FULL padding
3484 (kernel_shape[0] // 2, kernel_shape[1] // 2), # SAME padding
3485 (0, 0), # VALID padding
Les Bell0e027d42021-11-09 14:42:14 +00003486 ):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003487 h = get_out_size(
3488 input_shape[1],
3489 strides[0],
3490 kernel_shape[0],
3491 dilations[0],
3492 padding[0],
3493 pad_h,
3494 )
3495 w = get_out_size(
3496 input_shape[2],
3497 strides[1],
3498 kernel_shape[1],
3499 dilations[1],
3500 padding[1],
3501 pad_w,
3502 )
Les Bell0e027d42021-11-09 14:42:14 +00003503 if output_shape[1] == h and output_shape[2] == w:
3504 return False
3505
3506 # output shape does not match the expected shape for any padding option
Matthew Haddonb724efc2021-08-25 16:40:29 +01003507 return True
Les Bell0e027d42021-11-09 14:42:14 +00003508
3509 if "conv2d" in opName or "conv3d" in opName:
3510 # conv2d, conv3d, depthwise_conv2d
3511 dilations = args[2]
3512 filter_shape = inputShapes[1]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003513 kernel_shape = (
3514 filter_shape[0:2]
3515 if opName.startswith("depthwise_conv2d")
3516 else filter_shape[1:-1]
3517 )
Les Bell0e027d42021-11-09 14:42:14 +00003518
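            # Output dim per axis is
            # floor((in + pad_before + pad_after - dilated_kernel) / stride) + 1,
            # where dilated_kernel = kernel + (kernel - 1) * (dilation - 1).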
3519 for i in range(len(kernel_shape)):
3520 dim = (
3521 input_shape[i + 1]
3522 - kernel_shape[i]
3523 - (kernel_shape[i] - 1) * (dilations[i] - 1)
3524 + padding[i * 2 + 0]
3525 + padding[i * 2 + 1]
3526 ) // strides[i] + 1
3527 # return True if any dimension is < 1
3528 if dim < 1:
3529 return True
3530 return False
3531
3532 assert False, f"Unrecognized Op: {opName}"
Matthew Haddonb724efc2021-08-25 16:40:29 +01003533
3534 @staticmethod
3535 def ivNonPositiveOutputShape(**kwargs):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003536 args = kwargs["args"]
Matthew Haddonb724efc2021-08-25 16:40:29 +01003537 output_shape = args[3]
3538 if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Non-positive output shape
3540 return True
3541 return False
3542
3543
Eric Kunzee5e26762020-10-13 16:11:07 -07003544class TosaTestGen:
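    """Generator for TOSA tests: builds operator graphs with random tensors and serializes them."""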
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01003545 # Maximum rank of tensor supported by test generator.
3546 TOSA_TENSOR_MAX_RANK = 6
3547
Eric Kunzee5e26762020-10-13 16:11:07 -07003548 def __init__(self, args):
3549 self.args = args
3550 self.basePath = args.output_dir
3551 self.random_seed = args.random_seed
3552 self.ser = None
3553 self.rng = np.random.default_rng(self.random_seed)
3554 self.createDynamicOpLists()
3555 self.initOpListDefaults()
3556 self.quantGen = TosaQuantGen()
3557 # Force makeShape to do a specific starting shape
3558 self.targetted_shape = None
3559
3560 def createSerializer(self, opName, testPath):
3561 self.testPath = os.path.join(opName, testPath)
3562
3563 fullPath = os.path.join(self.basePath, self.testPath)
3564 os.makedirs(fullPath, exist_ok=True)
3565 self.ser = ts.TosaSerializer(fullPath)
3566
3567 def getSerializer(self):
3568 return self.ser
3569
3570 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003571 with open(
3572 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
3573 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -07003574 fd.write(self.ser.serialize())
3575
Kevin Cheng550ccc52021-03-03 11:21:43 -08003576 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
3577 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -07003578
Matthew Haddon74567092021-07-16 15:38:20 +01003579 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003580 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +01003581 seed = self.random_seed + 1
3582 self.rng = np.random.default_rng(seed)
3583
Eric Kunzee5e26762020-10-13 16:11:07 -07003584 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -07003585 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -07003586 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -07003587 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07003588 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07003589 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07003590 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01003591 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
3592 elif dtype == DType.UINT8:
3593 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07003594 elif dtype == DType.INT16:
3595 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
3596 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003597 return np.int32(
3598 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
3599 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003600 elif dtype == DType.INT48:
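            # INT48 values exceed the int32 range, so hold them in int64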
Kevin Cheng550ccc52021-03-03 11:21:43 -08003601 return np.int64(
3602 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
3603 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003604 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +01003605 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07003606 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003607 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003608
Kevin Cheng989cb052021-04-28 16:29:44 -07003609 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07003610 placeholders = []
3611
Kevin Cheng989cb052021-04-28 16:29:44 -07003612 assert len(shape_list) == len(dtype_list)
3613
3614 for idx, shape in enumerate(shape_list):
3615 arr = self.getRandTensor(shape, dtype_list[idx])
3616 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07003617
3618 return placeholders
3619
Kevin Cheng989cb052021-04-28 16:29:44 -07003620 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07003621 consts = []
3622
Kevin Cheng989cb052021-04-28 16:29:44 -07003623 assert len(shape_list) == len(dtype_list)
3624
3625 for idx, shape in enumerate(shape_list):
3626 arr = self.getRandTensor(shape, dtype_list[idx])
3627 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07003628
3629 return consts
3630
3631 def makeShape(self, rank):
3632 if self.targetted_shape:
3633 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003634 return np.int32(
3635 self.rng.integers(
3636 low=self.args.tensor_shape_range[0],
3637 high=self.args.tensor_shape_range[1],
3638 size=rank,
3639 )
3640 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003641
3642 def setTargetShape(self, shape):
3643 self.targetted_shape = shape
3644
3645 def randInt(self, low=0, high=256):
3646 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
3647
3648 def getRandNumberDType(self, dtype):
3649 if dtype == DType.FLOAT:
3650 return self.rng.random()
3651 elif dtype == DType.BOOL:
3652 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -07003653 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07003654 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07003655 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -07003656 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01003657 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -07003658 elif dtype == DType.INT16:
3659 low, high = (-32768, 32768)
3660 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003661 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -07003662 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003663 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -07003664 # Special size
3665 return np.int64(self.rng.integers(low, high, size=1))[0]
3666 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003667 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003668
3669 return np.int32(self.rng.integers(low, high, size=1))[0]
3670
3671 def shapeStr(self, shape):
3672
3673 sStr = []
3674 # Convert to strings
3675 for i in shape:
3676 sStr.append(str(i))
3677
Kevin Cheng550ccc52021-03-03 11:21:43 -08003678 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003679
3680 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -07003681 if isinstance(t, list):
3682 assert len(t) >= 2
3683 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -07003684 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07003685 if t == DType.BOOL:
3686 return "b"
3687 elif t == DType.INT4:
3688 return "i4"
3689 elif t == DType.INT8:
3690 return "i8"
3691 elif t == DType.UINT8:
3692 return "u8"
3693 elif t == DType.INT16:
3694 return "i16"
3695 elif t == DType.INT32:
3696 return "i32"
3697 elif t == DType.INT48:
3698 return "i48"
3699 elif t == DType.FLOAT:
3700 return "float"
3701 else:
3702 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07003703
3704 def typeWidth(self, t):
        """Get the data type width in bits."""
Kevin Cheng3a478572021-01-22 17:21:02 -08003706 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07003707 return 4
3708 elif t == DType.INT8:
3709 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08003710 elif t == DType.UINT8:
3711 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07003712 elif t == DType.INT16:
3713 return 16
3714 elif t == DType.INT32:
3715 return 32
3716 elif t == DType.INT48:
3717 return 48
Matthew Haddonc2025212021-10-08 21:21:05 +01003718 elif t == DType.FLOAT:
3719 return 32
3720 elif t == DType.BOOL:
3721 return 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003722 else:
Les Bell729b0352021-11-24 10:28:21 +00003723 raise Exception(f"Unknown dtype, cannot determine width: {t}")
Eric Kunzee5e26762020-10-13 16:11:07 -07003724
3725 # Argument generators
    # Each returns a list of tuples (stringDescriptor, [build_fcn_arg_list]), where
    # the string descriptor is used to generate the test name and
    # build_fcn_arg_list is expanded and passed to the operator test build function.
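    # For example, a generator might yield ("shift2", [2]), which would add
    # "shift2" to the test name and pass 2 as an extra argument to the build
    # function (illustrative values only).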
3730
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003731 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
3732 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
3733
        # build_placeholder returns an int; ABS and other ops do not
3735 if isinstance(op, int):
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003736 self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
3737 return result_tens
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003738 elif op["op"] == Op.IDENTITY:
3739 self.ser.addOperator(op["op"], a.name, result_tens.name, None, qinfo)
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003740 return result_tens
3741
3742 # Ensure new output type has correct qinfo
3743 if error_name == ErrorIf.WrongOutputType:
3744 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
3745 qinfo = ts.TosaSerializerQuantInfo()
3746 qinfo.UnaryQuantInfo(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003747 TosaQuantGen.getQinfo(self, a.dtype),
3748 TosaQuantGen.getQinfo(self, result_tens.dtype),
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003749 )
3750
3751 # Invalidate Input/Output list for error if checks.
3752 input_list = [a.name]
3753 output_list = [result_tens.name]
3754 pCount, cCount = op["operands"]
3755 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003756 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
3757 self, error_name, input_list, output_list
3758 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003759
Les Bell729b0352021-11-24 10:28:21 +00003760 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003761 self.ser,
3762 validator_fcns,
3763 error_name,
3764 op=op,
3765 input_dtype=a.dtype,
3766 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003767 qinfo=qinfo,
3768 result_tensor=result_tens,
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003769 input_list=input_list,
3770 output_list=output_list,
3771 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003772 ):
3773 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003774
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003775 self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07003776 return result_tens
3777
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003778 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003779 result_tens = OutputShaper.binaryBroadcastOp(
3780 self.ser, self.rng, a, b, error_name
3781 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003782
3783 # Invalidate Input/Output list for error if checks.
3784 input_list = [a.name, b.name]
3785 output_list = [result_tens.name]
3786 pCount, cCount = op["operands"]
3787 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003788 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
3789 self, error_name, input_list, output_list
3790 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003791
Les Bell729b0352021-11-24 10:28:21 +00003792 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003793 self.ser,
3794 validator_fcns,
3795 error_name,
3796 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003797 input1=a,
3798 input2=b,
3799 input_dtype=a.dtype,
3800 output_dtype=result_tens.dtype,
3801 result_tensor=result_tens,
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003802 input_list=input_list,
3803 output_list=output_list,
3804 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003805 ):
3806 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003807
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003808 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07003809 return result_tens
3810
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003811 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07003812 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003813 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07003814 return result_tens
3815
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003816 def build_arithmetic_right_shift(
3817 self, op, a, b, round, validator_fcns=None, error_name=None
3818 ):
3819 result_tens = OutputShaper.binaryBroadcastOp(
3820 self.ser, self.rng, a, b, error_name
3821 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003822
3823 # Invalidate Input/Output list for error if checks.
3824 input_list = [a.name, b.name]
3825 output_list = [result_tens.name]
3826 pCount, cCount = op["operands"]
3827 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003828 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
3829 self, error_name, input_list, output_list
3830 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003831
Les Bell729b0352021-11-24 10:28:21 +00003832 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003833 self.ser,
3834 validator_fcns,
3835 error_name,
3836 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003837 input1=a,
3838 input2=b,
3839 input_dtype=a.dtype,
3840 output_dtype=result_tens.dtype,
3841 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003842 input_list=input_list,
3843 output_list=output_list,
3844 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003845 ):
3846 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -08003847
3848 attr = ts.TosaSerializerAttribute()
3849 attr.ArithmeticRightShiftAttribute(round)
3850
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003851 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -08003852 return result_tens
3853
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003854 def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003855 result_tens = OutputShaper.binaryBroadcastOp(
3856 self.ser, self.rng, a, b, error_name
3857 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003858
3859 # Special for multiply:
3860 # Force the result to INT32 for INT types
3861 if a.dtype != DType.FLOAT:
3862 result_tens.setDtype(DType.INT32)
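        # For integer inputs, TOSA MUL produces an INT32 result, so the output dtype is widened here.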
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003863 if error_name == ErrorIf.WrongOutputType:
3864 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
3865 outputDType = self.rng.choice(all_dtypes)
3866 result_tens.setDtype(outputDType)
3867
3868 # Invalidate Input/Output list for error if checks.
3869 input_list = [a.name, b.name]
3870 output_list = [result_tens.name]
3871 pCount, cCount = op["operands"]
3872 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003873 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
3874 self, error_name, input_list, output_list
3875 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003876
Les Bell729b0352021-11-24 10:28:21 +00003877 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003878 self.ser,
3879 validator_fcns,
3880 error_name,
3881 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003882 input1=a,
3883 input2=b,
3884 input_dtype=a.dtype,
3885 output_dtype=result_tens.dtype,
3886 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003887 input_list=input_list,
3888 output_list=output_list,
3889 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003890 ):
3891 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07003892
Kevin Chengaee1fac2020-11-11 13:54:06 -08003893 attr = ts.TosaSerializerAttribute()
3894 attr.MulAttribute(shift)
3895
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003896 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003897 return result_tens
3898
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003899 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
3900 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07003901
Kevin Chengfe392ce2021-10-18 21:51:55 +00003902 attr = ts.TosaSerializerAttribute()
3903 attr.TableAttribute(table)
3904
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003905 # Invalidate Input/Output list for error if checks.
3906 input_list = [a.name]
3907 output_list = [result_tens.name]
3908 pCount, cCount = op["operands"]
3909 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003910 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
3911 self, error_name, input_list, output_list
3912 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003913
Les Bell729b0352021-11-24 10:28:21 +00003914 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003915 self.ser,
3916 validator_fcns,
3917 error_name,
3918 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003919 input_shape=a.shape,
3920 input_dtype=a.dtype,
3921 output_dtype=result_tens.dtype,
3922 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003923 input_list=input_list,
3924 output_list=output_list,
3925 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003926 ):
3927 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003928
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003929 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07003930
3931 return result_tens
3932
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003933 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
3934 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
3935
3936 # Invalidate Input/Output list for error if checks.
3937 input_list = [cond.name, a.name, b.name]
3938 output_list = [result_tens.name]
3939 pCount, cCount = op["operands"]
3940 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003941 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
3942 self, error_name, input_list, output_list
3943 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003944
Les Bell729b0352021-11-24 10:28:21 +00003945 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003946 self.ser,
3947 validator_fcns,
3948 error_name,
3949 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003950 input1=cond,
3951 input2=a,
3952 input3=b,
3953 input_shape=a.shape,
3954 input_dtype=a.dtype,
3955 output_dtype=result_tens.dtype,
3956 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003957 input_list=input_list,
3958 output_list=output_list,
3959 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003960 ):
3961 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003962
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003963 self.ser.addOperator(
3964 op["op"],
3965 input_list,
3966 output_list,
3967 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003968 return result_tens
3969
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003970 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003971 result_tens = OutputShaper.binaryComparisonOp(
3972 self.ser, self.rng, a, b, error_name
3973 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003974
3975 # Invalidate Input/Output list for error if checks.
3976 input_list = [a.name, b.name]
3977 output_list = [result_tens.name]
3978 pCount, cCount = op["operands"]
3979 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003980 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
3981 self, error_name, input_list, output_list
3982 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003983
Les Bell729b0352021-11-24 10:28:21 +00003984 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003985 self.ser,
3986 validator_fcns,
3987 error_name,
3988 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003989 input1=a,
3990 input2=b,
3991 input_shape=a.shape,
3992 input_dtype=a.dtype,
3993 output_shape=result_tens.shape,
3994 output_dtype=result_tens.dtype,
3995 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003996 input_list=input_list,
3997 output_list=output_list,
3998 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00003999 ):
4000 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004001
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004002 self.ser.addOperator(
4003 op["op"],
4004 input_list,
4005 output_list,
4006 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004007 return result_tens
4008
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004009 def build_argmax(self, op, a, axis, validator_fcns, error_name):
4010 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
4011
4012 # Invalidate Input/Output list for error if checks.
4013 input_list = [a.name]
4014 output_list = [result_tens.name]
4015 pCount, cCount = op["operands"]
4016 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004017 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4018 self, error_name, input_list, output_list
4019 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004020
Les Bell729b0352021-11-24 10:28:21 +00004021 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004022 self.ser,
4023 validator_fcns,
4024 error_name,
4025 op=op,
4026 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004027 input_shape=a.shape,
4028 input_dtype=a.dtype,
4029 output_shape=result_tens.shape,
4030 output_dtype=result_tens.dtype,
4031 result_tensor=result_tens,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004032 input_list=input_list,
4033 output_list=output_list,
4034 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004035 ):
4036 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004037
4038 attr = ts.TosaSerializerAttribute()
4039 attr.AxisAttribute(axis)
4040
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004041 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004042 return result_tens
4043
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004044 def build_pool2d(
4045 self,
4046 op,
4047 input,
4048 stride,
4049 pad,
4050 kernel,
4051 validator_fcns=None,
4052 error_name=None,
4053 qinfo=None,
4054 ):
4055 result_tens = OutputShaper.pool2dOp(
4056 self.ser, self.rng, input, kernel, stride, pad, error_name
4057 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004058
4059 # Ensure new output type has correct qinfo
4060 if error_name == ErrorIf.WrongInputType:
4061 if input.dtype not in [DType.INT8, DType.UINT8]:
4062 qinfo = ts.TosaSerializerQuantInfo()
4063 qinfo.UnaryQuantInfo(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004064 TosaQuantGen.getQinfo(self, input.dtype),
4065 TosaQuantGen.getQinfo(self, result_tens.dtype),
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004066 )
4067
4068 # Invalidate Input/Output list for error if checks.
4069 input_list = [input.name]
4070 output_list = [result_tens.name]
4071 pCount, cCount = op["operands"]
4072 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004073 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4074 self, error_name, input_list, output_list
4075 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004076
Les Bell729b0352021-11-24 10:28:21 +00004077 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004078 self.ser,
4079 validator_fcns,
4080 error_name,
4081 op=op,
4082 input_shape=input.shape,
4083 input_dtype=input.dtype,
4084 output_shape=result_tens.shape,
4085 output_dtype=result_tens.dtype,
4086 kernel=kernel,
4087 stride=stride,
4088 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004089 qinfo=qinfo,
4090 result_tensor=result_tens,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004091 input_list=input_list,
4092 output_list=output_list,
4093 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004094 ):
4095 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004096
4097 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07004098 attr.PoolAttribute(kernel, stride, pad)
Eric Kunzee5e26762020-10-13 16:11:07 -07004099
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004100 self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07004101 return result_tens
4102
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004103 def build_conv2d(
4104 self,
4105 op,
4106 ifm,
4107 filter,
4108 bias,
4109 strides,
4110 padding,
4111 dilations,
4112 validator_fcns=None,
4113 error_name=None,
4114 qinfo=None,
4115 ):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004116 assert len(padding) == 4
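        # Conv2D padding is assumed to be ordered [top, bottom, left, right]; the 3D variant below expects six values with the depth pair first.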
4117 result_tens = OutputShaper.conv2dOp(
Les Bell0e027d42021-11-09 14:42:14 +00004118 self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
4119 )
4120
4121 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004122 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
4123 DType.INT8,
4124 DType.UINT8,
4125 ):
Les Bell0e027d42021-11-09 14:42:14 +00004126 qinfo = ts.TosaSerializerQuantInfo()
4127 qinfo.ConvQuantInfo(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004128 TosaQuantGen.getQinfo(self, ifm.dtype),
4129 TosaQuantGen.getQinfo(self, result_tens.dtype),
Les Bell0e027d42021-11-09 14:42:14 +00004130 )
4131
4132 # Invalidate Input/Output list for error_if checks.
4133 input_list = [ifm.name, filter.name, bias.name]
4134 output_list = [result_tens.name]
4135 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004136 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4137 self, error_name, input_list, output_list
4138 )
Les Bell0e027d42021-11-09 14:42:14 +00004139
Les Bell729b0352021-11-24 10:28:21 +00004140 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +00004141 self.ser,
4142 validator_fcns,
4143 error_name,
4144 op=op,
4145 input_dtype=ifm.dtype,
4146 weight_dtype=filter.dtype,
4147 output_dtype=result_tens.dtype,
4148 qinfo=qinfo,
4149 input_list=input_list,
4150 num_operands=num_operands,
4151 output_list=output_list,
4152 pad=padding,
4153 stride=strides,
4154 dilation=dilations,
4155 input_shape=ifm.shape,
Les Bell729b0352021-11-24 10:28:21 +00004156 ):
4157 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004158
4159 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07004160 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07004161
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004162 self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07004163 return result_tens
4164
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004165 def build_conv3d(
4166 self,
4167 op,
4168 ifm,
4169 filter,
4170 bias,
4171 strides,
4172 padding,
4173 dilations,
4174 validator_fcns=None,
4175 error_name=None,
4176 qinfo=None,
4177 ):
Kevin Cheng1533b852021-09-01 12:51:58 -07004178 assert len(padding) == 6
4179 result_tens = OutputShaper.conv3dOp(
Les Bell0e027d42021-11-09 14:42:14 +00004180 self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
4181 )
4182
4183 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004184 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
4185 DType.INT8,
4186 DType.UINT8,
4187 ):
Les Bell0e027d42021-11-09 14:42:14 +00004188 qinfo = ts.TosaSerializerQuantInfo()
4189 qinfo.ConvQuantInfo(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004190 TosaQuantGen.getQinfo(self, ifm.dtype),
4191 TosaQuantGen.getQinfo(self, result_tens.dtype),
Les Bell0e027d42021-11-09 14:42:14 +00004192 )
4193
4194 # Invalidate Input/Output list for error_if checks.
4195 input_list = [ifm.name, filter.name, bias.name]
4196 output_list = [result_tens.name]
4197 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004198 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4199 self, error_name, input_list, output_list
4200 )
Les Bell0e027d42021-11-09 14:42:14 +00004201
Les Bell729b0352021-11-24 10:28:21 +00004202 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +00004203 self.ser,
4204 validator_fcns,
4205 error_name,
4206 op=op,
4207 input_dtype=ifm.dtype,
4208 weight_dtype=filter.dtype,
4209 output_dtype=result_tens.dtype,
4210 qinfo=qinfo,
4211 input_list=input_list,
4212 num_operands=num_operands,
4213 output_list=output_list,
4214 pad=padding,
4215 stride=strides,
4216 dilation=dilations,
4217 input_shape=ifm.shape,
Les Bell729b0352021-11-24 10:28:21 +00004218 ):
4219 return None
Kevin Cheng1533b852021-09-01 12:51:58 -07004220
4221 attr = ts.TosaSerializerAttribute()
4222 attr.ConvAttribute(padding, strides, dilations)
4223
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004224 self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
Kevin Cheng1533b852021-09-01 12:51:58 -07004225 return result_tens
4226
Kevin Cheng550ccc52021-03-03 11:21:43 -08004227 def build_transpose_conv2d(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004228 self,
4229 op,
4230 ifm,
4231 filter,
4232 bias,
4233 stride,
4234 outpad,
4235 dilation,
4236 output_shape,
4237 validator_fcns=None,
4238 error_name=None,
4239 qinfo=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004240 ):
4241 assert len(outpad) == 2
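        # The 2-element output padding here is assumed to be the (top, left) pair used by this version of the spec; later spec versions widen it to four values.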
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004242 result_tens = OutputShaper.transposeConv2DOp(
4243 self.ser, self.rng, ifm, output_shape, error_name
4244 )
Les Bell0e027d42021-11-09 14:42:14 +00004245
4246 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004247 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
4248 DType.INT8,
4249 DType.UINT8,
4250 ):
Les Bell0e027d42021-11-09 14:42:14 +00004251 qinfo = ts.TosaSerializerQuantInfo()
4252 qinfo.ConvQuantInfo(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004253 TosaQuantGen.getQinfo(self, ifm.dtype),
4254 TosaQuantGen.getQinfo(self, result_tens.dtype),
Les Bell0e027d42021-11-09 14:42:14 +00004255 )
4256
4257 # Invalidate Input/Output list for error_if checks.
4258 input_list = [ifm.name, filter.name, bias.name]
4259 output_list = [result_tens.name]
4260 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004261 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4262 self, error_name, input_list, output_list
4263 )
Les Bell0e027d42021-11-09 14:42:14 +00004264
Les Bell729b0352021-11-24 10:28:21 +00004265 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +00004266 self.ser,
4267 validator_fcns,
4268 error_name,
4269 op=op,
4270 input_dtype=ifm.dtype,
4271 weight_dtype=filter.dtype,
4272 output_dtype=result_tens.dtype,
4273 qinfo=qinfo,
4274 input_list=input_list,
4275 num_operands=num_operands,
4276 output_list=output_list,
4277 pad=outpad,
4278 stride=stride,
4279 dilation=dilation,
4280 input_shape=ifm.shape,
Les Bell729b0352021-11-24 10:28:21 +00004281 ):
4282 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004283
4284 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07004285 attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07004286
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004287 self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07004288 return result_tens
4289
Kevin Cheng550ccc52021-03-03 11:21:43 -08004290 def build_depthwise_conv2d(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004291 self,
4292 op,
4293 ifm,
4294 filter,
4295 bias,
4296 strides,
4297 padding,
4298 dilations,
4299 validator_fcns=None,
4300 error_name=None,
4301 qinfo=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004302 ):
4303 result_tens = OutputShaper.depthwiseConv2dOp(
Les Bell0e027d42021-11-09 14:42:14 +00004304 self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
4305 )
4306
4307 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004308 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
4309 DType.INT8,
4310 DType.UINT8,
4311 ):
Les Bell0e027d42021-11-09 14:42:14 +00004312 qinfo = ts.TosaSerializerQuantInfo()
4313 qinfo.ConvQuantInfo(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004314 TosaQuantGen.getQinfo(self, ifm.dtype),
4315 TosaQuantGen.getQinfo(self, result_tens.dtype),
Les Bell0e027d42021-11-09 14:42:14 +00004316 )
4317
4318 # Invalidate Input/Output list for error_if checks.
4319 input_list = [ifm.name, filter.name, bias.name]
4320 output_list = [result_tens.name]
4321 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004322 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4323 self, error_name, input_list, output_list
4324 )
Les Bell0e027d42021-11-09 14:42:14 +00004325
Les Bell729b0352021-11-24 10:28:21 +00004326 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +00004327 self.ser,
4328 validator_fcns,
4329 error_name,
4330 op=op,
4331 input_dtype=ifm.dtype,
4332 weight_dtype=filter.dtype,
4333 output_dtype=result_tens.dtype,
4334 qinfo=qinfo,
4335 input_list=input_list,
4336 num_operands=num_operands,
4337 output_list=output_list,
4338 pad=padding,
4339 stride=strides,
4340 dilation=dilations,
4341 input_shape=ifm.shape,
Les Bell729b0352021-11-24 10:28:21 +00004342 ):
4343 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004344
4345 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07004346 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07004347
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004348 self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07004349 return result_tens
4350
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004351 def build_fully_connected(
4352 self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None
4353 ):
4354 result_tens = OutputShaper.fullyConnectedOp(
4355 self.ser, self.rng, ifm, filter, error_name
4356 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004357
4358 # Invalidate Input/Output list for error if checks.
4359 input_list = [ifm.name, filter.name, bias.name]
4360 output_list = [result_tens.name]
4361 pCount, cCount = op["operands"]
4362 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004363 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4364 self, error_name, input_list, output_list
4365 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004366
Les Bell729b0352021-11-24 10:28:21 +00004367 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004368 self.ser,
4369 validator_fcns,
4370 error_name,
4371 op=op,
4372 input_shape=ifm.shape,
4373 input_dtype=ifm.dtype,
4374 weight_dtype=filter.dtype,
4375 output_shape=result_tens.shape,
4376 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004377 qinfo=qinfo,
4378 result_tensor=result_tens,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004379 input_list=input_list,
4380 output_list=output_list,
4381 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004382 ):
4383 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004384
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004385 self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07004386 return result_tens
4387
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004388 def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
4389 result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)
4390
4391 # Invalidate Input/Output list for error if checks.
4392 input_list = [a.name, b.name]
4393 output_list = [result_tens.name]
4394 pCount, cCount = op["operands"]
4395 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004396 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4397 self, error_name, input_list, output_list
4398 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004399
Les Bell729b0352021-11-24 10:28:21 +00004400 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004401 self.ser,
4402 validator_fcns,
4403 error_name,
4404 op=op,
4405 input_shape=a.shape,
4406 input_dtype=a.dtype,
4407 input2_shape=b.shape,
4408 input2_dtype=b.dtype,
4409 output_shape=result_tens.shape,
4410 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004411 qinfo=qinfo,
4412 result_tensor=result_tens,
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004413 input_list=input_list,
4414 output_list=output_list,
4415 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004416 ):
4417 return None
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004418
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004419 self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07004420 return result_tens
4421
Matthew Haddond6ce7252021-09-29 15:35:44 +01004422 def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
4423 result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)
4424
4425 # Invalidate Input/Output list for error if checks.
4426 input_list = [a.name]
4427 output_list = [result_tens.name]
4428 pCount, cCount = op["operands"]
4429 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004430 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4431 self, error_name, input_list, output_list
4432 )
Matthew Haddond6ce7252021-09-29 15:35:44 +01004433
Les Bell729b0352021-11-24 10:28:21 +00004434 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddond6ce7252021-09-29 15:35:44 +01004435 self.ser,
4436 validator_fcns,
4437 error_name,
4438 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004439 axis=axis,
4440 input_shape=a.shape,
4441 output_shape=result_tens.shape,
4442 input_dtype=a.dtype,
4443 output_dtype=result_tens.dtype,
4444 result_tensor=result_tens,
Matthew Haddond6ce7252021-09-29 15:35:44 +01004445 input_list=input_list,
4446 output_list=output_list,
4447 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004448 ):
4449 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004450
4451 attr = ts.TosaSerializerAttribute()
4452 attr.AxisAttribute(axis)
4453
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004454 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004455 return result_tens
4456
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004457 def build_clamp(self, op, a, validator_fcns=None, error_name=None):
4458 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004459
Jeremy Johnson18e26662021-07-22 16:15:29 +01004460 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07004461
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004462 if error_name == ErrorIf.MaxSmallerMin:
4463 # Make sure the numbers are different to invoke this error
4464 while v[0] == v[1]:
4465 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
4466 max_val = min(v)
4467 min_val = max(v)
Eric Kunzee5e26762020-10-13 16:11:07 -07004468 else:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004469 max_val = max(v)
4470 min_val = min(v)
Eric Kunzee5e26762020-10-13 16:11:07 -07004471
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004472 # Invalidate Input/Output list for error if checks.
4473 input_list = [a.name]
4474 output_list = [result_tens.name]
4475 pCount, cCount = op["operands"]
4476 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004477 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4478 self, error_name, input_list, output_list
4479 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004480
Les Bell729b0352021-11-24 10:28:21 +00004481 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004482 self.ser,
4483 validator_fcns,
4484 error_name,
4485 op=op,
4486 max_val=max_val,
4487 min_val=min_val,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004488 input_shape=a.shape,
4489 output_shape=result_tens.shape,
4490 input_dtype=a.dtype,
4491 output_dtype=result_tens.dtype,
4492 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004493 input_list=input_list,
4494 output_list=output_list,
4495 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004496 ):
4497 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004498
 4499         attr = ts.TosaSerializerAttribute()
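        # ClampAttribute is assumed to take (min_int, max_int, min_fp, max_fp); the pair that does not apply to the tensor dtype is left as zero.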
4500 if a.dtype == DType.FLOAT:
4501 attr.ClampAttribute(0, 0, min_val, max_val)
4502 else:
4503 attr.ClampAttribute(min_val, max_val, 0, 0)
4504
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004505 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004506 return result_tens
4507
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004508 def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
4509 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004510 attr = ts.TosaSerializerAttribute()
4511
4512 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
4513
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004514 self.ser.addOperator(op["op"], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004515 return result_tens
4516
4517 # Needs an additional type/input
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004518 def build_prelu(self, op, a, validator_fcns=None, error_name=None):
4519 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004520
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004521 self.ser.addOperator(op["op"], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07004522 return result_tens
4523
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004524 def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
4525 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
4526
4527 # Invalidate Input/Output list for error if checks.
4528 input_list = [a.name]
4529 output_list = [result_tens.name]
4530 pCount, cCount = op["operands"]
4531 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004532 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4533 self, error_name, input_list, output_list
4534 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004535
Les Bell729b0352021-11-24 10:28:21 +00004536 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004537 self.ser,
4538 validator_fcns,
4539 error_name,
4540 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004541 input_shape=a.shape,
4542 output_shape=result_tens.shape,
4543 input_dtype=a.dtype,
4544 output_dtype=result_tens.dtype,
4545 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004546 input_list=input_list,
4547 output_list=output_list,
4548 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004549 ):
4550 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004551
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004552 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07004553 return result_tens
4554
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004555 def build_tanh(self, op, a, validator_fcns=None, error_name=None):
4556 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
4557
4558 # Invalidate Input/Output list for error if checks.
4559 input_list = [a.name]
4560 output_list = [result_tens.name]
4561 pCount, cCount = op["operands"]
4562 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004563 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4564 self, error_name, input_list, output_list
4565 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004566
Les Bell729b0352021-11-24 10:28:21 +00004567 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004568 self.ser,
4569 validator_fcns,
4570 error_name,
4571 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004572 input_shape=a.shape,
4573 output_shape=result_tens.shape,
4574 input_dtype=a.dtype,
4575 output_dtype=result_tens.dtype,
4576 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004577 input_list=input_list,
4578 output_list=output_list,
4579 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004580 ):
4581 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004582
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004583 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07004584 return result_tens
4585
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004586 def build_concat(self, op, *a, validator_fcns=None, error_name=None):
4587 if error_name != ErrorIf.WrongInputType:
4588 assert type(a[-1]) == int
Matthew Haddon818ab902021-07-27 09:12:49 +01004589
 4590         # The variable-length input tensor list is passed with the axis appended as its last element, so peel the axis off here
4591 axis = a[-1]
4592 a = a[:-1]
4593
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004594 result_tens = OutputShaper.concatOp(
4595 self.ser, self.rng, axis, *a, error_name=error_name
4596 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004597
Matthew Haddon818ab902021-07-27 09:12:49 +01004598 input_tensor_names = []
4599 for tensor in a:
4600 input_tensor_names.append(tensor.name)
4601
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004602 # Invalidate Input/Output list for error if checks.
4603 input_list = input_tensor_names
4604 output_list = [result_tens.name]
4605 pCount, cCount = op["operands"]
4606 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004607 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4608 self, error_name, input_list, output_list
4609 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004610
Les Bell729b0352021-11-24 10:28:21 +00004611 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004612 self.ser,
4613 validator_fcns,
4614 error_name,
4615 op=op,
4616 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004617 input_shape=a[0].shape,
4618 output_shape=result_tens.shape,
4619 input_dtype=a[0].dtype,
4620 output_dtype=result_tens.dtype,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004621 inputs=a,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004622 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004623 input_list=input_list,
4624 output_list=output_list,
4625 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004626 ):
4627 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004628
4629 attr = ts.TosaSerializerAttribute()
4630 attr.AxisAttribute(axis)
4631
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004632 self.ser.addOperator(op["op"], input_list, output_list, attr)
Matthew Haddon848efb42021-09-09 12:30:53 +01004633 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07004634
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004635 def build_pad(
4636 self,
4637 op,
4638 a,
4639 padding,
4640 pad_const_int,
4641 pad_const_float,
4642 validator_fcns=None,
4643 error_name=None,
4644 qinfo=None,
4645 ):
Matthew Haddone807aae2021-10-11 18:12:58 +01004646 result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004647
Kevin Chengfe392ce2021-10-18 21:51:55 +00004648 attr = ts.TosaSerializerAttribute()
4649 attr.PadAttribute(padding.flatten(), pad_const_int, pad_const_float)
Eric Kunzee5e26762020-10-13 16:11:07 -07004650
Matthew Haddone807aae2021-10-11 18:12:58 +01004651 # Invalidate Input/Output list for error if checks.
Kevin Chengfe392ce2021-10-18 21:51:55 +00004652 input_list = [a.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01004653 output_list = [result_tens.name]
4654 pCount, cCount = op["operands"]
4655 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004656 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4657 self, error_name, input_list, output_list
4658 )
Matthew Haddone807aae2021-10-11 18:12:58 +01004659
Les Bell729b0352021-11-24 10:28:21 +00004660 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01004661 self.ser,
4662 validator_fcns,
4663 error_name,
4664 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004665 input_shape=a.shape,
4666 output_shape=result_tens.shape,
4667 input_dtype=a.dtype,
4668 output_dtype=result_tens.dtype,
Matthew Haddone807aae2021-10-11 18:12:58 +01004669 pad=padding,
4670 qinfo=qinfo,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004671 result_tensor=result_tens,
Matthew Haddone807aae2021-10-11 18:12:58 +01004672 input_list=input_list,
4673 output_list=output_list,
4674 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004675 ):
4676 return None
Matthew Haddone807aae2021-10-11 18:12:58 +01004677
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004678 self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
Matthew Haddone86fd342021-09-07 16:12:21 +01004679 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07004680
Matthew Haddone807aae2021-10-11 18:12:58 +01004681 def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004682 result_tens = OutputShaper.reshapeOp(
4683 self.ser, self.rng, a, newShape, error_name
4684 )
Matthew Haddone807aae2021-10-11 18:12:58 +01004685
4686 # Invalidate Input/Output list for error if checks.
4687 input_list = [a.name]
4688 output_list = [result_tens.name]
4689 pCount, cCount = op["operands"]
4690 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004691 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4692 self, error_name, input_list, output_list
4693 )
Matthew Haddone807aae2021-10-11 18:12:58 +01004694
Les Bell729b0352021-11-24 10:28:21 +00004695 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01004696 self.ser,
4697 validator_fcns,
4698 error_name,
4699 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004700 input_shape=a.shape,
4701 output_shape=result_tens.shape,
4702 input_dtype=a.dtype,
4703 output_dtype=result_tens.dtype,
4704 result_tensor=result_tens,
Matthew Haddone807aae2021-10-11 18:12:58 +01004705 input_list=input_list,
4706 output_list=output_list,
4707 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004708 ):
4709 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004710
4711 attr = ts.TosaSerializerAttribute()
4712 attr.ReshapeAttribute(newShape)
4713
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004714 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004715 return result_tens
4716
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004717 def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
4718 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
4719
4720 # Invalidate Input/Output list for error if checks.
4721 input_list = [a.name]
4722 output_list = [result_tens.name]
4723 pCount, cCount = op["operands"]
4724 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004725 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4726 self, error_name, input_list, output_list
4727 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004728
Les Bell729b0352021-11-24 10:28:21 +00004729 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004730 self.ser,
4731 validator_fcns,
4732 error_name,
4733 op=op,
4734 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004735 input_shape=a.shape,
4736 output_shape=result_tens.shape,
4737 input_dtype=a.dtype,
4738 output_dtype=result_tens.dtype,
4739 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004740 input_list=input_list,
4741 output_list=output_list,
4742 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004743 ):
4744 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004745
4746 attr = ts.TosaSerializerAttribute()
4747 attr.AxisAttribute(axis)
4748
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004749 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004750 return result_tens
4751
Matthew Haddone807aae2021-10-11 18:12:58 +01004752 def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
4753 result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07004754
Kevin Chengfe392ce2021-10-18 21:51:55 +00004755 attr = ts.TosaSerializerAttribute()
4756 attr.TransposeAttribute(perms)
Eric Kunzee5e26762020-10-13 16:11:07 -07004757
Matthew Haddone807aae2021-10-11 18:12:58 +01004758 # Invalidate Input/Output list for error if checks.
Kevin Chengfe392ce2021-10-18 21:51:55 +00004759 input_list = [a.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01004760 output_list = [result_tens.name]
4761 pCount, cCount = op["operands"]
4762 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004763 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4764 self, error_name, input_list, output_list
4765 )
Matthew Haddone807aae2021-10-11 18:12:58 +01004766
Les Bell729b0352021-11-24 10:28:21 +00004767 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01004768 self.ser,
4769 validator_fcns,
4770 error_name,
4771 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004772 input_shape=a.shape,
4773 output_shape=result_tens.shape,
Matthew Haddone807aae2021-10-11 18:12:58 +01004774 perms=perms,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004775 input_dtype=a.dtype,
4776 output_dtype=result_tens.dtype,
4777 result_tensor=result_tens,
Matthew Haddone807aae2021-10-11 18:12:58 +01004778 input_list=input_list,
4779 output_list=output_list,
4780 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004781 ):
4782 return None
Matthew Haddone807aae2021-10-11 18:12:58 +01004783
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004784 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004785 return result_tens
4786
Matthew Haddone807aae2021-10-11 18:12:58 +01004787 def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004788 result_tens = OutputShaper.sliceOp(
4789 self.ser, self.rng, a, start, size, error_name
4790 )
Matthew Haddone807aae2021-10-11 18:12:58 +01004791
4792 # Invalidate Input/Output list for error if checks.
4793 input_list = [a.name]
4794 output_list = [result_tens.name]
4795 pCount, cCount = op["operands"]
4796 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004797 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4798 self, error_name, input_list, output_list
4799 )
Matthew Haddone807aae2021-10-11 18:12:58 +01004800
Les Bell729b0352021-11-24 10:28:21 +00004801 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01004802 self.ser,
4803 validator_fcns,
4804 error_name,
4805 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004806 input_shape=a.shape,
4807 output_shape=result_tens.shape,
4808 input_dtype=a.dtype,
4809 output_dtype=result_tens.dtype,
Matthew Haddone807aae2021-10-11 18:12:58 +01004810 start=start,
4811 size=size,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004812 result_tensor=result_tens,
Matthew Haddone807aae2021-10-11 18:12:58 +01004813 input_list=input_list,
4814 output_list=output_list,
4815 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004816 ):
4817 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004818
4819 attr = ts.TosaSerializerAttribute()
Matthew Haddone807aae2021-10-11 18:12:58 +01004820 attr.SliceAttribute(start, size)
Eric Kunzee5e26762020-10-13 16:11:07 -07004821
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004822 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004823 return result_tens
4824
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004825 def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
4826 result_tens = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)
4827
4828 # Invalidate Input/Output list for error if checks.
4829 input_list = [a.name]
4830 output_list = [result_tens.name]
4831 pCount, cCount = op["operands"]
4832 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004833 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4834 self, error_name, input_list, output_list
4835 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004836
Les Bell729b0352021-11-24 10:28:21 +00004837 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004838 self.ser,
4839 validator_fcns,
4840 error_name,
4841 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004842 input_shape=a.shape,
4843 output_shape=result_tens.shape,
4844 input_dtype=a.dtype,
4845 output_dtype=result_tens.dtype,
4846 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004847 input_list=input_list,
4848 output_list=output_list,
4849 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004850 ):
4851 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07004852
4853 attr = ts.TosaSerializerAttribute()
4854 attr.TileAttribute(multiples)
4855
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004856 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004857 return result_tens
4858
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004859 def build_gather(self, op, values, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004860
4861 # Create a new indicies tensor
4862 # here with data that doesn't exceed the dimensions of the values tensor
4863
Kevin Cheng550ccc52021-03-03 11:21:43 -08004864 K = values.shape[1] # K
4865 W = self.randInt(
4866 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
4867 ) # W
4868 indicies_arr = np.int32(
4869 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
4870 ) # (N, W)
4871 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07004872
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004873 result_tens = OutputShaper.gatherOp(
4874 self.ser, self.rng, values, indicies, error_name
4875 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004876
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004877 # Invalidate Input/Output list for error if checks.
4878 input_list = [values.name, indicies.name]
4879 output_list = [result_tens.name]
4880 pCount, cCount = op["operands"]
4881 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004882 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4883 self, error_name, input_list, output_list
4884 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004885
Les Bell729b0352021-11-24 10:28:21 +00004886 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004887 self.ser,
4888 validator_fcns,
4889 error_name,
4890 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004891 input_shape=values.shape,
4892 output_shape=result_tens.shape,
4893 input_dtype=values.dtype,
4894 output_dtype=result_tens.dtype,
4895 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004896 input_list=input_list,
4897 output_list=output_list,
4898 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004899 ):
4900 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004901
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004902 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07004903
4904 return result_tens
4905
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004906 def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
Kevin Cheng77d0f762020-11-24 10:26:32 -08004907
 4908         # Create a new indices tensor here, with data that doesn't exceed
 4909         # the dimensions of the values_in tensor
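        # SCATTER semantics per the TOSA spec: values_in is (N, K, C), indices is (N, W) with entries in [0, K), input is (N, W, C), and the output copies values_in with the indexed rows overwritten.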
4910
Kevin Cheng550ccc52021-03-03 11:21:43 -08004911 K = values_in.shape[1] # K
4912 W = input.shape[1] # W
4913 indicies_arr = np.int32(
4914 self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
4915 ) # (N, W)
4916 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004917
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004918 result_tens = OutputShaper.scatterOp(
4919 self.ser, self.rng, values_in, indicies, input, error_name
4920 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08004921
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004922 # Invalidate Input/Output list for error if checks.
4923 input_list = [values_in.name, indicies.name, input.name]
4924 output_list = [result_tens.name]
4925 pCount, cCount = op["operands"]
4926 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004927 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4928 self, error_name, input_list, output_list
4929 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004930
Les Bell729b0352021-11-24 10:28:21 +00004931 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004932 self.ser,
4933 validator_fcns,
4934 error_name,
4935 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004936 input_shape=values_in.shape,
4937 output_shape=result_tens.shape,
4938 input_dtype=values_in.dtype,
4939 output_dtype=result_tens.dtype,
4940 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004941 input_list=input_list,
4942 output_list=output_list,
4943 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00004944 ):
4945 return None
Kevin Cheng77d0f762020-11-24 10:26:32 -08004946
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004947 self.ser.addOperator(op["op"], input_list, output_list)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004948
Kevin Cheng77d0f762020-11-24 10:26:32 -08004949 return result_tens
4950
Kevin Cheng550ccc52021-03-03 11:21:43 -08004951 def build_resize(
4952 self,
4953 op,
4954 input,
4955 mode,
4956 stride,
4957 offset,
4958 shift,
4959 stride_fp,
4960 offset_fp,
4961 output_dims,
4962 input_dtype,
4963 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01004964 validator_fcns,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004965 error_name=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004966 ):
4967 result_tens = OutputShaper.resizeOp(
4968 self.ser,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004969 self.rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004970 input,
4971 mode,
4972 stride,
4973 offset,
4974 shift,
4975 stride_fp,
4976 offset_fp,
4977 output_dims,
4978 input_dtype,
4979 output_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004980 error_name,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004981 )
Eric Kunzee5e26762020-10-13 16:11:07 -07004982
Matthew Haddon848efb42021-09-09 12:30:53 +01004983 # Invalidate Input/Output list for error if checks.
4984 input_list = [input.name]
4985 output_list = [result_tens.name]
4986 pCount, cCount = op["operands"]
4987 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004988 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
4989 self, error_name, input_list, output_list
4990 )
Matthew Haddone86fd342021-09-07 16:12:21 +01004991
Les Bell729b0352021-11-24 10:28:21 +00004992 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon848efb42021-09-09 12:30:53 +01004993 self.ser,
4994 validator_fcns,
4995 error_name,
4996 op=op,
4997 mode=mode,
4998 shift=shift,
4999 input_dtype=input_dtype,
5000 output_dtype=output_dtype,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005001 input_shape=input.shape,
Matthew Haddon848efb42021-09-09 12:30:53 +01005002 output_shape=output_dims,
5003 offset=offset,
5004 offset_fp=offset_fp,
5005 stride=stride,
5006 stride_fp=stride_fp,
5007 input_list=input_list,
5008 output_list=output_list,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01005009 result_tensor=result_tens,
Matthew Haddon848efb42021-09-09 12:30:53 +01005010 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00005011 ):
5012 return None
Matthew Haddone86fd342021-09-07 16:12:21 +01005013
Eric Kunzee5e26762020-10-13 16:11:07 -07005014 attr = ts.TosaSerializerAttribute()
Kevin Cheng77d0f762020-11-24 10:26:32 -08005015
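        # Both parameter sets are serialized: the fixed-point stride/offset (scaled by 2^shift) and the floating-point stride_fp/offset_fp; which pair the reference implementation reads is assumed to depend on the input dtype and resize mode.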
Kevin Cheng550ccc52021-03-03 11:21:43 -08005016 attr.ResizeAttribute(
5017 output_dims, stride, offset, shift, stride_fp, offset_fp, mode
5018 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005019
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005020 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07005021 return result_tens
5022
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005023 def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
5024 result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
5025 result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005026 self.ser.addOperator(
 5027             op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
5028 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005029 return result_tens
5030
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005031 def build_const(self, op, val, validator_fcns=None, error_name=None):
Kevin Cheng17e92022021-10-01 14:33:33 -07005032 self.ser.addOutputTensor(val)
5033 return val
Eric Kunzee5e26762020-10-13 16:11:07 -07005034
5035 # Type Conversion
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005036 def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005037 result_tens = OutputShaper.typeConversionOp(
5038 self.ser, self.rng, val, out_dtype, error_name
5039 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005040
5041 # Invalidate Input/Output list for error if checks.
5042 input_list = [val.name]
5043 output_list = [result_tens.name]
5044 pCount, cCount = op["operands"]
5045 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005046 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
5047 self, error_name, input_list, output_list
5048 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005049
Les Bell729b0352021-11-24 10:28:21 +00005050 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005051 self.ser,
5052 validator_fcns,
5053 error_name,
5054 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005055 input_shape=val.shape,
5056 output_shape=result_tens.shape,
5057 input_dtype=val.dtype,
5058 output_dtype=result_tens.dtype,
5059 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005060 input_list=input_list,
5061 output_list=output_list,
5062 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00005063 ):
5064 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005065
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005066 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07005067 return result_tens
5068
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005069 def build_rescale(
5070 self,
5071 op,
5072 val,
5073 out_dtype,
5074 scale32,
5075 double_round,
5076 per_channel,
5077 validator_fcns,
5078 error_name,
5079 ):
5080 result_tens = OutputShaper.typeConversionOp(
5081 self.ser, self.rng, val, out_dtype, error_name
5082 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005083
5084 if per_channel:
5085 nc = val.shape[-1]
5086 else:
5087 nc = 1
5088
5089 in_type_width = self.typeWidth(val.dtype)
5090 out_type_width = self.typeWidth(out_dtype)
5091
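        # Zero points are only meaningful for (U)INT8 tensors; the extra bit added to the type width below is assumed to leave headroom for the zero-point offset when choosing the scale.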
Kevin Cheng3a478572021-01-22 17:21:02 -08005092 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01005093 input_zp = self.randInt(-128, 128)
5094 in_type_width = in_type_width + 1
Kevin Chengacb550f2021-06-29 15:32:19 -07005095 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01005096 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07005097 in_type_width = in_type_width + 1
Matthew Haddonc2025212021-10-08 21:21:05 +01005098 elif error_name == ErrorIf.InputZeroPointNotZero:
5099 input_zp = self.randInt(-128, 128)
5100 if input_zp == 0:
5101 input_zp = input_zp + self.rng.integers(1, 10)
5102 in_type_width = in_type_width + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005103 else:
5104 input_zp = 0
5105
Kevin Cheng3a478572021-01-22 17:21:02 -08005106 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01005107 output_zp = self.randInt(-128, 128)
5108 out_type_width = out_type_width + 1
5109 elif out_dtype == DType.UINT8:
5110 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07005111 out_type_width = out_type_width + 1
Matthew Haddonc2025212021-10-08 21:21:05 +01005112 elif error_name == ErrorIf.OutputZeroPointNotZero:
5113 output_zp = self.randInt(-128, 128)
5114 if output_zp == 0:
5115 output_zp = output_zp + self.rng.integers(1, 10)
5116 out_type_width = out_type_width + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005117 else:
5118 output_zp = 0
5119
5120 # Calculate scale based on:
 5121         # scale = a * (2^output_width) / (2^input_width)
5122
5123 a = np.float32(self.rng.random(size=[nc]))
5124 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
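        # Worked example (assuming the widths above): an INT16 -> INT8 rescale gives 2^9 / 2^16, so scale_arr = a / 128 with a drawn from [0, 1).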
5125
5126 if scale32:
Matthew Haddonb724efc2021-08-25 16:40:29 +01005128 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07005129 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
5130 else:
5131 # Cap the scaling at 2^15 - 1 for scale16
5132 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
5133
Kevin Cheng550ccc52021-03-03 11:21:43 -08005134 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07005135
5136 multiplier_arr = np.int32(np.zeros(shape=[nc]))
5137 shift_arr = np.int32(np.zeros(shape=[nc]))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00005138 min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
5139 max_shift_value_arr = np.int64(np.zeros(shape=[nc]))
Eric Kunzee5e26762020-10-13 16:11:07 -07005140
5141 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005142 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
5143 scale_arr[i], scale32
5144 )
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00005145 min_shift_value_arr[i] = -1 << (shift_arr[i] - 2)
5146 max_shift_value_arr[i] = (1 << (shift_arr[i] - 2)) - 1
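        # min/max_shift_value_arr bound the zero-point-adjusted input data so that
        # apply_scale_32's REQUIRE(value >= -(1 << (shift - 2)) && value < (1 << (shift - 2)))
        # holds; the placeholder data is clamped to this range below when scale32 is used.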
Eric Kunzee5e26762020-10-13 16:11:07 -07005147
Kevin Cheng550ccc52021-03-03 11:21:43 -08005148 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00005149 if scale32 and error_name is None:
5150            # Make sure random values are within the apply_scale_32 specification
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00005151            # REQUIRES(value >= (-1 << (shift - 2)) && value < (1 << (shift - 2)))
5152 assert val.placeholderFilename
5153 values = np.load(
5154 os.path.join(self.basePath, self.testPath, val.placeholderFilename)
5155 )
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00005156 val_adj = np.subtract(values, input_zp, dtype=np.int64)
5157 val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
5158 val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
5159 val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00005160 if not np.all(np.array_equal(values, val_adj)):
5161 # Values changed so overwrite file with new values
5162 np.save(
5163 os.path.join(self.basePath, self.testPath, val.placeholderFilename),
5164 val_adj,
5165 False,
5166 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005167
Matthew Haddonc2025212021-10-08 21:21:05 +01005168 # Invalidate Input/Output list for error if checks.
5169 input_list = [val.name]
5170 output_list = [result_tens.name]
5171 pCount, cCount = op["operands"]
5172 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005173 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
5174 self, error_name, input_list, output_list
5175 )
Matthew Haddonc2025212021-10-08 21:21:05 +01005176
5177 qinfo = (input_zp, output_zp)
Les Bell729b0352021-11-24 10:28:21 +00005178 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc2025212021-10-08 21:21:05 +01005179 self.ser,
5180 validator_fcns,
5181 error_name,
5182 op=op,
5183 input_dtype=val.dtype,
5184 output_dtype=out_dtype,
5185 input_shape=val.shape,
5186 qinfo=qinfo,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005187 scale32=scale32,
5188 double_round=double_round,
Matthew Haddonc2025212021-10-08 21:21:05 +01005189 input_list=input_list,
5190 output_list=output_list,
5191 result_tensor=result_tens,
5192 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00005193 ):
5194 return None
Matthew Haddonc2025212021-10-08 21:21:05 +01005195
Eric Kunzee5e26762020-10-13 16:11:07 -07005196 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08005197 attr.RescaleAttribute(
5198 input_zp,
5199 output_zp,
5200 multiplier_arr,
5201 shift_arr,
5202 scale32,
5203 double_round,
5204 per_channel,
5205 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005206
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005207 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07005208 return result_tens
5209
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005210 def build_cond_if_const(
5211 self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
5212 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07005213 # For cond_if with constants, we're supplied with then/else tensors that we ignore
5214        # (except for the generated shape) and the condition. Build Then/Else blocks
5215 # and fill them with const nodes for the body.
5216
5217 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08005218 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07005219
5220 # Make then/else tensors
5221 out_shape = then_tens.shape
Matthew Haddon630c17c2021-10-14 15:05:41 +01005222
5223 # Create an incorrect output shape for error_if tests
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005224 if error_name in [
5225 ErrorIf.CondIfOutputListThenGraphMismatch,
5226 ErrorIf.CondIfOutputListElseGraphMismatch,
5227 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01005228 incorrect_shape = deepcopy(then_tens.shape)
5229 for i in range(len(incorrect_shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005230 incorrect_shape[i] += (
5231 self.rng.choice([-3, -2, 2, 3])
5232 if incorrect_shape[i] > 3
5233 else self.rng.choice([1, 2, 4])
5234 )
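            # Perturbing each dimension by +/-2 or +/-3 (or by +1/+2/+4 when the
            # dimension is small) guarantees a shape mismatch while keeping every
            # dimension positive.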
Matthew Haddon630c17c2021-10-14 15:05:41 +01005235 incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))
5236
Jeremy Johnson18e26662021-07-22 16:15:29 +01005237 then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
5238 else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07005239
5240 # And the result tensor based on any of the outputs
Kevin Cheng550ccc52021-03-03 11:21:43 -08005241 result_tens = self.ser.addOutput(out_shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07005242
5243 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08005244 then_block = "THEN_BLOCK"
5245 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07005246 attr = ts.TosaSerializerAttribute()
5247 attr.CondIfAttribute(then_block, else_block)
5248
5249 # Finally, build the op and the two blocks
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005250 self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07005251
5252 self.ser.startBasicBlock(then_block)
5253 # Build the actual then/else tensors inside their blocks
Matthew Haddon630c17c2021-10-14 15:05:41 +01005254 if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
5255 then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
5256 else:
5257 then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07005258 self.ser.addOutputTensor(then_tens)
5259
5260 self.ser.startBasicBlock(else_block)
Matthew Haddon630c17c2021-10-14 15:05:41 +01005261 if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
5262 else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
5263 else:
5264 else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07005265 self.ser.addOutputTensor(else_tens)
5266
Les Bell729b0352021-11-24 10:28:21 +00005267 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01005268 self.ser,
5269 validator_fcns,
5270 error_name,
5271 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005272 basicBlocks=self.ser.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00005273 ):
5274 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01005275
Eric Kunzee5e26762020-10-13 16:11:07 -07005276 return result_tens
5277
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005278 def build_cond_if_binary(
5279 self, op, a, b, cond, validator_fcns=None, error_name=None
5280 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07005281 # For cond_if with a binary op in the then/else blocks, take a and b and
5282 # alternately add or subtract them based on the condition
5283
5284 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08005285 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07005286
Kevin Cheng550ccc52021-03-03 11:21:43 -08005287 result_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005288
5289 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08005290 then_block = "THEN_BLOCK"
5291 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07005292 attr = ts.TosaSerializerAttribute()
5293 attr.CondIfAttribute(then_block, else_block)
5294
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005295 if error_name in [
5296 ErrorIf.CondIfInputListThenGraphMismatch,
5297 ErrorIf.CondIfInputListElseGraphMismatch,
5298 ErrorIf.CondIfOutputListElseGraphMismatch,
5299 ErrorIf.CondIfOutputListThenGraphMismatch,
5300 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01005301 incorrect_shape = a.shape.copy()
5302 for i in range(len(incorrect_shape)):
5303 incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
5304 incorrect_block_input = deepcopy(a)
5305 incorrect_block_input.shape = incorrect_shape
5306
Eric Kunzee5e26762020-10-13 16:11:07 -07005307 # Finally, build the op and the two blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08005308 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005309 op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
Kevin Cheng550ccc52021-03-03 11:21:43 -08005310 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005311
Les Bell6040b4d2021-10-11 12:50:31 +01005312 if a.dtype in (DType.FLOAT, DType.INT32):
5313 then_op, else_op = Op.ADD, Op.SUB
5314 elif a.dtype in (DType.INT8, DType.INT16):
5315 then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
5316 else:
5317 assert False, f"No tests for DType: {a.dtype}"
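        # Narrow integer types use logical shifts instead of add/sub so the block
        # result stays in range; generate_tensors constrains INT8/INT16 inputs for
        # COND_IF to values in [0, 31] so they are valid shift amounts.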
Eric Kunzee5e26762020-10-13 16:11:07 -07005318
Les Bell6040b4d2021-10-11 12:50:31 +01005319 for block, op in ((then_block, then_op), (else_block, else_op)):
5320 self.ser.startBasicBlock(block)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005321 if (
5322 error_name == ErrorIf.CondIfInputListThenGraphMismatch
5323 and block == then_block
5324 ) or (
5325 error_name == ErrorIf.CondIfInputListElseGraphMismatch
5326 and block == else_block
5327 ):
Matthew Haddon630c17c2021-10-14 15:05:41 +01005328 self.ser.addInputTensor(incorrect_block_input)
5329 self.ser.addInputTensor(b)
5330 tens = self.ser.addOutput(a.shape, a.dtype)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005331 elif (
5332 error_name == ErrorIf.CondIfOutputListThenGraphMismatch
5333 and block == then_block
5334 ) or (
5335 error_name == ErrorIf.CondIfOutputListElseGraphMismatch
5336 and block == else_block
5337 ):
Matthew Haddon630c17c2021-10-14 15:05:41 +01005338 self.ser.addInputTensor(a)
5339 self.ser.addInputTensor(b)
5340 tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
5341 else:
5342 self.ser.addInputTensor(a)
5343 self.ser.addInputTensor(b)
5344 tens = self.ser.addOutput(a.shape, a.dtype)
Les Bell6040b4d2021-10-11 12:50:31 +01005345 self.ser.addOperator(op, [a.name, b.name], [tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07005346
Les Bell729b0352021-11-24 10:28:21 +00005347 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01005348 self.ser,
5349 validator_fcns,
5350 error_name,
5351 op=op,
5352 a=a,
5353 b=b,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005354 basicBlocks=self.ser.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00005355 ):
5356 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01005357
Eric Kunzee5e26762020-10-13 16:11:07 -07005358 return result_tens
5359
Matthew Haddon630c17c2021-10-14 15:05:41 +01005360 def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005361 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
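        # The generated graph computes: while (iter > 0) { acc += a; iter -= 1 },
        # so the final accumulator should equal a * iter_val.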
Eric Kunzee5e26762020-10-13 16:11:07 -07005362
Kevin Cheng550ccc52021-03-03 11:21:43 -08005363 cond_block = "COND_BLOCK"
5364 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07005365
5366 attr = ts.TosaSerializerAttribute()
5367 attr.WhileLoopAttribute(cond_block, body_block)
5368
5369 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08005370 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005371 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08005372 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07005373
5374 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08005375 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
5376 a_out = self.ser.addIntermediate(a.shape, a.dtype)
Matthew Haddon630c17c2021-10-14 15:05:41 +01005377 if error_name == ErrorIf.InputListOutputListMismatch:
5378 incorrect_acc = deepcopy(acc)
5379 for i in range(len(incorrect_acc.shape)):
5380 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
5381 acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
5382 else:
5383 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005384
5385 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08005386 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005387 op["op"],
Kevin Cheng550ccc52021-03-03 11:21:43 -08005388 [iter.name, a.name, acc.name],
5389 [iter_out.name, a_out.name, acc_out.name],
5390 attr,
5391 )
Kevin Chengb227ae52021-09-02 13:43:17 -07005392 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07005393
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005394 if error_name in [
5395 ErrorIf.InputListCondGraphMismatch,
5396 ErrorIf.InputListBodyGraphInputMismatch,
5397 ErrorIf.InputListBodyGraphOutputMismatch,
5398 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01005399 incorrect_iter = deepcopy(iter)
5400 for i in range(len(incorrect_iter.shape)):
5401 incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
5402 if len(incorrect_iter.shape) == 0:
5403 incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
5404
5405 incorrect_acc = deepcopy(acc)
5406 for i in range(len(incorrect_acc.shape)):
5407 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
5408
Eric Kunzee5e26762020-10-13 16:11:07 -07005409        # COND block (inputs: iter, a, acc; output: cond_tens)
5410 self.ser.startBasicBlock(cond_block)
Matthew Haddon630c17c2021-10-14 15:05:41 +01005411 if error_name == ErrorIf.InputListCondGraphMismatch:
5412 self.ser.addInputTensor(incorrect_iter)
5413 self.ser.addInputTensor(a)
5414 self.ser.addInputTensor(incorrect_acc)
5415 else:
5416 self.ser.addInputTensor(iter)
5417 self.ser.addInputTensor(a)
5418 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005419 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01005420
5421 if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005422 cond_tens = self.ser.addOutput(
5423 [], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT])
5424 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01005425 else:
5426 cond_tens = self.ser.addOutput([], DType.BOOL)
5427
Kevin Cheng550ccc52021-03-03 11:21:43 -08005428 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07005429
5430 # BODY block (input: a, acc, iter, output: a, acc, iter)
5431 # Note that local intermediate tensors need to be declared here for the outputs
5432 self.ser.startBasicBlock(body_block)
Matthew Haddon630c17c2021-10-14 15:05:41 +01005433 if error_name == ErrorIf.InputListBodyGraphInputMismatch:
5434 self.ser.addInputTensor(incorrect_iter)
5435 self.ser.addInputTensor(a)
5436 self.ser.addInputTensor(incorrect_acc)
5437 else:
5438 self.ser.addInputTensor(iter)
5439 self.ser.addInputTensor(a)
5440 self.ser.addInputTensor(acc)
5441
Kevin Cheng550ccc52021-03-03 11:21:43 -08005442 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01005443
5444 if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005445 iter_body_out = self.ser.addIntermediate(
5446 incorrect_iter.shape, incorrect_iter.dtype
5447 )
5448 acc_body_out = self.ser.addIntermediate(
5449 incorrect_acc.shape, incorrect_acc.dtype
5450 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01005451 else:
5452 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
5453 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
5454
Eric Kunzee5e26762020-10-13 16:11:07 -07005455 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
5456 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
5457 self.ser.addOutputTensor(iter_body_out)
5458 self.ser.addOutputTensor(a)
5459 self.ser.addOutputTensor(acc_body_out)
5460
Les Bell729b0352021-11-24 10:28:21 +00005461 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01005462 self.ser,
5463 validator_fcns,
5464 error_name,
5465 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005466 basicBlocks=self.ser.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00005467 ):
5468 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01005469
Eric Kunzee5e26762020-10-13 16:11:07 -07005470 return acc_out
5471
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005472 def create_filter_lists(
5473 self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
5474 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01005475 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
5476 default_test_rank_range = range(1, 5)
5477 if not shapeFilter:
5478 shapeFilter = [None]
5479
5480 # Calculate the filters based on what is requested and what the operator allows
5481 rmin, rmax = op["rank"]
5482 if rankFilter is not None:
5483 cleanRankFilter = []
5484 # Ensure rankFilter values are allowed by operator
5485 for rank in rankFilter:
5486 if rank >= rmin and rank <= rmax:
5487 cleanRankFilter.append(rank)
5488 elif rankFilter is None and shapeFilter[0] is None:
Jeremy Johnson03bec732021-10-07 12:06:00 +01005489 # Ensure default behaviour is bounded by default range or by operator,
5490 # whichever is the smaller range of ranks.
5491 opRankRange = range(rmin, rmax + 1)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005492 cleanRankFilter = (
5493 opRankRange
5494 if len(opRankRange) <= len(default_test_rank_range)
5495 else default_test_rank_range
5496 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01005497 else:
5498 cleanRankFilter = range(rmin, rmax + 1)
5499
5500 dtypes = op["types"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005501
Matthew Haddon1c00b712021-10-01 15:51:03 +01005502 if dtypeFilter is not None:
5503 cleanDtypeFilter = []
Jeremy Johnson03bec732021-10-07 12:06:00 +01005504 # Create list of operator dtypes filtered by requested dtypes
5505 for dtype in dtypes:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005506 if dtype in dtypeFilter or (
5507 isinstance(dtype, list) and dtype[0] in dtypeFilter
5508 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01005509 cleanDtypeFilter.append(dtype)
5510 else:
5511 cleanDtypeFilter = dtypes
5512
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005513 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01005514 filterDict = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005515 "shapeFilter": shapeFilter,
5516 "rankFilter": cleanRankFilter,
5517 "dtypeFilter": cleanDtypeFilter,
Matthew Haddon1c00b712021-10-01 15:51:03 +01005518 }
5519 return filterDict
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005520 elif testType == "negative":
Matthew Haddone807aae2021-10-11 18:12:58 +01005521 if validator is not None:
5522 validator_info = validator(check=False, op=op)
5523 else:
5524 return None
5525
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005526 error_arguments = validator_info["param_reqs"]
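            # param_reqs reports the rank/dtype/shape constraints the validator
            # needs to trigger this ERROR_IF case; where a constraint is None the
            # positive-test filters are reused below.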
Matthew Haddon1c00b712021-10-01 15:51:03 +01005527
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005528 # Set parameters as required
5529 if error_arguments["rank"] is not None:
5530 rankFilter = error_arguments["rank"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01005531 else:
5532 rankFilter = cleanRankFilter
5533
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005534 if error_arguments["dtype"] is not None:
5535 dtypeFilter = error_arguments["dtype"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01005536 else:
5537 dtypeFilter = cleanDtypeFilter
5538
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005539 if error_arguments["shape"] is not None:
5540 shapeFilter = error_arguments["shape"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01005541 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005542 shapeFilter = shapeFilter[
5543 :2
5544 ] # Reduce number of shapes to keep test numbers small
Matthew Haddon1c00b712021-10-01 15:51:03 +01005545
5546 filterDict = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005547 "shapeFilter": shapeFilter,
5548 "rankFilter": rankFilter,
5549 "dtypeFilter": dtypeFilter,
Matthew Haddon1c00b712021-10-01 15:51:03 +01005550 }
5551 return filterDict
5552
Kevin Cheng550ccc52021-03-03 11:21:43 -08005553 def genOpTestList(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005554 self,
5555 opName,
5556 shapeFilter=[None],
5557 rankFilter=None,
5558 dtypeFilter=None,
5559 testType="positive",
Kevin Cheng550ccc52021-03-03 11:21:43 -08005560 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07005561
5562 try:
5563 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005564 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005565 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07005566
5567 # Initialize a new random number generator
5568 self.rng = np.random.default_rng(self.random_seed)
5569
Kevin Cheng550ccc52021-03-03 11:21:43 -08005570 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07005571
Eric Kunzee5e26762020-10-13 16:11:07 -07005572 # Test list consists of a tuple of:
5573        # (opName, testNameStr, dtype, errorName, shapeList, argumentsList)
5574 testList = []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005575 if testType == "negative" and "error_if_validators" in op:
Matthew Haddon1c00b712021-10-01 15:51:03 +01005576 error_if_validators = op["error_if_validators"]
5577 else:
5578 error_if_validators = [None]
Eric Kunzee5e26762020-10-13 16:11:07 -07005579
Matthew Haddon1c00b712021-10-01 15:51:03 +01005580 for validator in error_if_validators:
5581 if validator is not None:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005582 error_name = validator(check=False, op=op)["error_name"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01005583 else:
5584 error_name = None
5585
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005586 filterDict = self.create_filter_lists(
5587 op, shapeFilter, rankFilter, dtypeFilter, testType, validator
5588 )
5589 if filterDict is None:
Matthew Haddone807aae2021-10-11 18:12:58 +01005590 return []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005591 cleanRankFilter = filterDict["rankFilter"]
5592 cleanDtypeFilter = filterDict["dtypeFilter"]
5593 cleanShapeFilter = filterDict["shapeFilter"]
5594 # print(f"Error: {error_name}, Filters: S {cleanShapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01005595
5596 for r in cleanRankFilter:
Matthew Haddon1c00b712021-10-01 15:51:03 +01005597 for t in cleanDtypeFilter:
5598 for shape in cleanShapeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01005599 # Filter out by rank
5600 if shape is not None and len(shape) != r:
5601 continue
Matthew Haddon74567092021-07-16 15:38:20 +01005602 self.setTargetShape(shape)
Matthew Haddon1c00b712021-10-01 15:51:03 +01005603 shapeList = tgen_fcn(self, op, r, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07005604
Matthew Haddon74567092021-07-16 15:38:20 +01005605 shapeStr = self.shapeStr(shapeList[0])
5606 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07005607
Matthew Haddon74567092021-07-16 15:38:20 +01005608                        # Argument lists consist of tuples of the (str, []) string representation and the build function argument list
5609 argList = []
5610 if agen_fcn:
Matthew Haddon1c00b712021-10-01 15:51:03 +01005611 argList = agen_fcn(self, opName, shapeList, t, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07005612 else:
Matthew Haddon74567092021-07-16 15:38:20 +01005613 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07005614
Matthew Haddon74567092021-07-16 15:38:20 +01005615 for argStr, args in argList:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005616 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01005617 if argStr:
5618 testStr = "{}_{}_{}_{}".format(
5619 opName, shapeStr, typeStr, argStr
5620 )
5621 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005622 testStr = "{}_{}_{}".format(
5623 opName, shapeStr, typeStr
5624 )
5625 elif testType == "negative":
Matthew Haddone86fd342021-09-07 16:12:21 +01005626 if argStr:
5627 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
5628 opName, error_name, shapeStr, typeStr, argStr
5629 )
5630 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005631 testStr = "{}_ERRORIF_{}_{}_{}".format(
5632 opName, error_name, shapeStr, typeStr
5633 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01005634
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005635 testList.append(
5636 (opName, testStr, t, error_name, shapeList, args)
5637 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01005638
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005639 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01005640            # Remove tests which are expected to fail but don't correlate to an ERROR_IF statement
5641 if "invalid_test_validators" in op:
5642 invalid_test_validators = op["invalid_test_validators"]
5643 clean_testList = []
5644 for test in testList:
5645 for validator_fcn in invalid_test_validators:
5646 remove_test = False
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005647 if validator_fcn(
5648 opName=test[0],
5649 input_dtype=test[2],
5650 shapeList=test[4],
5651 args=test[5],
5652 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01005653 remove_test = True
5654 if not remove_test:
5655 clean_testList.append(test)
5656 testList = clean_testList
Eric Kunzee5e26762020-10-13 16:11:07 -07005657
5658 return testList
5659
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005660 def serializeTest(
5661 self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
5662 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07005663 try:
5664 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005665 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005666 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07005667
5668 # Create a serializer
5669 self.createSerializer(opName, testStr)
5670
Kevin Cheng550ccc52021-03-03 11:21:43 -08005671 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01005672 if "error_if_validators" in op:
5673 error_if_validators = op["error_if_validators"]
5674 else:
5675 error_if_validators = None
5676
Kevin Cheng550ccc52021-03-03 11:21:43 -08005677 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07005678 num_operands = pCount + cCount
5679
5680 if isinstance(dtype_or_dtypeList, list):
5681 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07005682 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01005683 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07005684 else:
5685 dtypeList = [dtype_or_dtypeList] * (num_operands)
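        # CONCAT takes one dtype per input tensor; all other ops replicate the
        # single requested dtype across every placeholder and const operand.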
5686
Kevin Cheng93a16282021-08-31 16:14:03 -07005687 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01005688 assert (
5689 len(shapeList) == num_operands
5690 ), "shapeList length {} must match number of operands {}".format(
5691 len(shapeList), num_operands
5692 )
5693 assert (
5694 len(dtypeList) == num_operands
5695 ), "dtypeList length {} must match number of operands {}".format(
5696 len(dtypeList), num_operands
5697 )
Eric Kunzee5e26762020-10-13 16:11:07 -07005698
5699 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005700 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07005701 except KeyError:
5702 qgen = None
5703
5704 # Build the random tensor operands and the test
5705 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08005706
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005707 tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01005708
5709 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005710 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01005711 else:
5712 qinfo = None
5713
5714 try:
5715 if error_if_validators is None:
5716 if qinfo is not None:
5717 resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
5718 else:
5719 resultName = build_fcn(self, op, *tens, *testArgs)
5720 else:
5721 if qinfo is not None:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005722 resultName = build_fcn(
5723 self,
5724 op,
5725 *tens,
5726 *testArgs,
5727 validator_fcns=error_if_validators,
5728 error_name=error_name,
5729 qinfo=qinfo,
5730 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01005731 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005732 resultName = build_fcn(
5733 self,
5734 op,
5735 *tens,
5736 *testArgs,
5737 validator_fcns=error_if_validators,
5738 error_name=error_name,
5739 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01005740 except TypeError as e:
Les Bell0e027d42021-11-09 14:42:14 +00005741 print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
Matthew Haddon1c00b712021-10-01 15:51:03 +01005742 raise e
5743
Les Bell729b0352021-11-24 10:28:21 +00005744 if resultName:
5745 # The test is valid, serialize it
5746 self.serialize("test")
5747 else:
5748 # The test is not valid
5749 print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01005750
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005751 def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
Matthew Haddon1c00b712021-10-01 15:51:03 +01005752 pCount, cCount = op["operands"]
5753
5754 tens = []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005755 if (
5756 (op["op"] == Op.ADD or op["op"] == Op.SUB)
5757 and dtypeList[0] == DType.INT32
5758 and error_name is None
5759 ):
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005760 # Make sure the operation does not cause value saturation - where
5761 # the number wraps due to limited number of bits to store the answer
5762 assert (
5763 pCount == 2 and cCount == 0
5764 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005765 placeholders = []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005766 add = op["op"] == Op.ADD
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005767 a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
5768 b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
5769 if add:
5770 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
5771 else:
5772 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
5773
5774 # Work out the saturation limits
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005775 max_i32 = (1 << 31) - 1
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005776 min_i32 = -(1 << 31)
5777 max_arr = np.full(shapeList[1], max_i32)
5778 min_arr = np.full(shapeList[1], min_i32)
5779
5780 # Find how much values exceed the maximum/minimums
5781 sat_max_arr = np.maximum(res_arr - max_arr, 0)
5782 sat_min_arr = np.minimum(res_arr - min_arr, 0)
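            # Worked example (per element): a = 2**31 - 10 and b = 25 give
            # res = 2**31 + 15, so sat_max = 16 and b is clipped to 9, leaving an
            # in-range sum of 2**31 - 1.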
5783
5784 if not add:
5785 # Swap saturation values and negate values as we need to perform opposite operations
5786 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
5787
5788 # Create new array of unsaturated values by clipping values as needed
5789 b_unsat_arr = b_arr
5790 if (sat_max_arr != 0).any():
5791 # Clip values that cause saturation
5792 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
5793 # Reduce axes in unsaturated tensor to match original tensor
5794 for axis, dim in enumerate(b_arr.shape):
5795 if dim != b_unsat_arr.shape[axis]:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005796 assert (
5797 dim == 1
5798 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005799 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
5800
5801 if (sat_min_arr != 0).any():
5802 # Clip values that cause saturation
5803 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
5804 # Reduce axes in unsaturated tensor to match original tensor
5805 for axis, dim in enumerate(b_arr.shape):
5806 if dim != b_unsat_arr.shape[axis]:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005807 assert (
5808 dim == 1
5809 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005810 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
5811
5812 placeholders.append(
5813 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
5814 )
5815 placeholders.append(
5816 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
5817 )
5818
5819 tens.extend(placeholders)
Jeremy Johnson0c171ac2022-01-24 12:24:21 +00005820 elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[0] in (
5821 DType.INT32,
5822 DType.INT16,
5823 DType.INT8,
5824 ):
Jeremy Johnson8c06a652021-10-20 15:51:11 +01005825 # Limit input tensors with cond_if_binary or while_loop to stop
Jeremy Johnson0c171ac2022-01-24 12:24:21 +00005826 # saturation of add/sub ops with int32 and keep all logical shift
5827            # values between 0 and 31 for int16 or int8
Jeremy Johnson8c06a652021-10-20 15:51:11 +01005828 pRemain = pCount
5829 placeholders = []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005830 for idx, shape in enumerate(shapeList[:]):
Jeremy Johnson0c171ac2022-01-24 12:24:21 +00005831 if dtypeList[0] == DType.INT32:
5832 arr = self.getRandTensor(shapeList[idx], DType.INT16)
5833 else:
5834 arr = np.int32(
5835 self.rng.integers(low=0, high=32, size=shapeList[idx])
5836 )
Jeremy Johnson8c06a652021-10-20 15:51:11 +01005837 if pRemain > 0:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005838 placeholders.append(
5839 self.ser.addPlaceholder(shape, dtypeList[idx], arr)
5840 )
Jeremy Johnson8c06a652021-10-20 15:51:11 +01005841 pRemain -= 1
5842 else:
5843 placeholders.append(self.ser.addConst(shape, dtypeList[idx], arr))
5844
5845 tens.extend(placeholders)
Jeremy Johnsonef509a42021-09-07 13:59:47 +01005846 elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
5847 # Force value of operand[1] to be within [0, num_bits]
Kevin Cheng550ccc52021-03-03 11:21:43 -08005848 assert (
5849 pCount == 2 and cCount == 0
5850 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
Kevin Chengaee1fac2020-11-11 13:54:06 -08005851
5852 placeholders = []
5853 for idx, shape in enumerate(shapeList[:]):
5854 if idx == 1:
Kevin Cheng989cb052021-04-28 16:29:44 -07005855 if dtypeList[idx] == DType.INT8:
Kevin Chengaee1fac2020-11-11 13:54:06 -08005856 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07005857 elif dtypeList[idx] == DType.INT16:
Kevin Chengaee1fac2020-11-11 13:54:06 -08005858 arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07005859 elif dtypeList[idx] == DType.INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -08005860 arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005861 elif error_name == ErrorIf.WrongInputType:
5862 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Chengaee1fac2020-11-11 13:54:06 -08005863 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08005864 raise Exception("OpArithmeticRightShift: invalid input dtype")
Kevin Chengaee1fac2020-11-11 13:54:06 -08005865 else:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005866 arr = self.getRandTensor(shape, dtypeList[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -07005867 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
Kevin Chengaee1fac2020-11-11 13:54:06 -08005868
5869 tens.extend(placeholders)
Matthew Haddona44ac5e2021-07-27 16:31:16 +01005870 elif op["op"] == Op.SELECT:
5871 # Set datatype of condition tensor to boolean
5872 dtypeList[0] = DType.BOOL
5873 tens.extend(
5874 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
5875 )
5876 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005877 elif op["op"] == Op.INTDIV and error_name is None:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005878 assert (
5879 pCount == 2 and cCount == 0
Matthew Haddon459443c2021-08-23 16:43:13 +01005880 ), "Op.INTDIV must have 2 placeholders, 0 consts"
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005881
5882 placeholders = []
5883
Matthew Haddon459443c2021-08-23 16:43:13 +01005884 # Two invalid cases for Op.INTDIV:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005885 # 1. divisor == 0
Kevin Cheng47315e12021-05-13 17:41:28 -07005886 # 2. dividend == -(1<<31) and divisor == -1
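            # (case 2 would overflow: -(1 << 31) / -1 == 1 << 31, which is not
            # representable in int32)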
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005887 while True:
5888 dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
5889 divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])
5890
5891 if (divisor_arr == 0).any():
5892 continue
5893
Kevin Cheng47315e12021-05-13 17:41:28 -07005894 if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005895 continue
5896
5897 break
5898
5899 placeholders.append(
5900 self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
5901 )
5902 placeholders.append(
5903 self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
5904 )
5905
5906 tens.extend(placeholders)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005907 elif op["op"] == Op.MUL and error_name is None:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005908 assert (
5909 pCount == 2 and cCount == 0
5910 ), "Op.MUL must have 2 placeholders, 0 consts"
5911
5912 if dtypeList[0] == DType.FLOAT:
5913 tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
5914 else:
5915 placeholders = []
5916
5917                # Make sure the multiply result is within int32 range
5918 shift = testArgs[0]
5919 if dtypeList[0] == DType.INT8:
5920 num_bits = 8
5921 elif dtypeList[0] == DType.INT16:
5922 num_bits = 16
5923 elif dtypeList[0] == DType.INT32:
5924 num_bits = 32
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005925 elif error_name == ErrorIf.WrongInputType:
5926 num_bits = 8
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07005927 else:
5928 raise Exception("OpMul: invalid input dtype")
5929
5930 for idx, shape in enumerate(shapeList[:]):
5931 low = -(2 ** (num_bits - 1))
5932 high = (2 ** (num_bits - 1)) - 1
5933
5934 a_arr = np.int32(
5935 self.rng.integers(low=low, high=high, size=shapeList[0])
5936 )
5937 b_arr = np.int32(
5938 self.rng.integers(low=low, high=high, size=shapeList[1])
5939 )
5940
5941 i = 0
5942 while True:
5943
5944 a_arr_64 = a_arr.astype(np.int64)
5945 b_arr_64 = b_arr.astype(np.int64)
5946
5947 if shift > 0:
5948 rounding = 1 << (shift - 1)
5949 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
5950 else:
5951 result_arr = a_arr_64 * b_arr_64
5952
5953 if (result_arr > -(2 ** 31)).all() and (
5954 result_arr <= ((2 ** 31) - 1)
5955 ).all():
5956 break
5957
5958 i = i + 1
5959 a_arr = a_arr // 2
5960 b_arr = b_arr // 2
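                        # Halving both operands shrinks the product by roughly 4x
                        # per retry, so this quickly converges to a result that
                        # fits in int32.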
5961
5962 placeholders.append(
5963 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
5964 )
5965 placeholders.append(
5966 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
5967 )
5968
5969 tens.extend(placeholders)
Matthew Haddon818ab902021-07-27 09:12:49 +01005970 elif op["op"] == Op.CONCAT:
5971 count = len(shapeList) - self.args.num_const_inputs_concat
5972 if count < 1:
5973 count = 1
5974 if self.args.num_const_inputs_concat == 0:
5975 count = len(shapeList)
5976
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005977 # Ensure axis is an int
5978 testArgs[0] = int(testArgs[0])
5979
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005980 shapeList = TosaTensorGen.tgConcatConstInput(
5981 self, shapeList, testArgs[0], error_name
5982 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005983
Matthew Haddon818ab902021-07-27 09:12:49 +01005984 tens.extend(
5985 self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
5986 )
5987 tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
Jeremy Johnson66bad802022-01-18 14:48:35 +00005988 elif op["op"] == Op.LOGICAL_LEFT_SHIFT or op["op"] == Op.LOGICAL_RIGHT_SHIFT:
5989 assert (
5990 pCount == 2 and cCount == 0
5991 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
5992 values_arr = self.getRandTensor(shapeList[0], dtypeList[0])
5993 shift_arr = np.int32(self.rng.integers(low=0, high=32, size=shapeList[1]))
5994 placeholders = []
5995 placeholders.append(
5996 self.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
5997 )
5998 placeholders.append(
5999 self.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
6000 )
6001 tens.extend(placeholders)
Kevin Chengaee1fac2020-11-11 13:54:06 -08006002 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07006003 tens.extend(
6004 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
6005 )
6006 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Eric Kunzee5e26762020-10-13 16:11:07 -07006007
Matthew Haddon1c00b712021-10-01 15:51:03 +01006008 return tens
Eric Kunzee5e26762020-10-13 16:11:07 -07006009
6010 def createDynamicOpLists(self):
6011
6012 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng1533b852021-09-01 12:51:58 -07006013 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
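        # e.g. the "conv2d_TEMPLATE" entry below is expanded into concrete ops
        # such as "conv2d_1x1", "conv2d_2x2", ..., one per kernel size listed here.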
Eric Kunzee5e26762020-10-13 16:11:07 -07006014
Kevin Cheng1533b852021-09-01 12:51:58 -07006015 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08006016 testName = "conv2d_{}x{}".format(k[0], k[1])
6017 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
6018 self.TOSA_OP_LIST[testName]["filter"] = k
6019 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07006020
Kevin Cheng550ccc52021-03-03 11:21:43 -08006021 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
6022 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
6023 "depthwise_conv2d_TEMPLATE"
6024 ].copy()
6025 self.TOSA_OP_LIST[testName]["filter"] = k
6026 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07006027
Kevin Cheng550ccc52021-03-03 11:21:43 -08006028 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
6029 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
6030 "transpose_conv2d_TEMPLATE"
6031 ].copy()
6032 self.TOSA_OP_LIST[testName]["filter"] = k
6033 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07006034
Kevin Cheng1533b852021-09-01 12:51:58 -07006035 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
6036 for k in KERNELS_3D:
6037 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
6038 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
6039 self.TOSA_OP_LIST[testName]["filter"] = k
6040 self.TOSA_OP_LIST[testName]["template"] = False
6041
Eric Kunzee5e26762020-10-13 16:11:07 -07006042 # Delete any templates after having created any dynamic ops
6043 # This is a two-pass operation because it's bad practice to delete
6044 # keys from dictionaries while iterating
6045 keyList = []
6046 for k in self.TOSA_OP_LIST:
6047 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006048 if self.TOSA_OP_LIST[k]["template"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07006049 keyList.append(k)
6050 continue
6051 except KeyError:
6052 pass
6053
6054 for k in keyList:
6055 del self.TOSA_OP_LIST[k]
6056
6057 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08006058 """Fill in default fields for ops if they aren't already specified.
6059        Look for missing required fields (data structure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07006060 for op in self.TOSA_OP_LIST:
6061
6062 # Required fields
6063 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08006064 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07006065 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08006066 raise Exception(
6067 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
6068 )
Eric Kunzee5e26762020-10-13 16:11:07 -07006069
6070 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08006071 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07006072 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08006073 raise Exception(
6074 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
6075 op
6076 )
6077 )
Eric Kunzee5e26762020-10-13 16:11:07 -07006078
6079 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006080 _ = self.TOSA_OP_LIST[op]["types"]
6081 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08006082 raise Exception(
6083 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
6084 )
Eric Kunzee5e26762020-10-13 16:11:07 -07006085
6086 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006087 _ = self.TOSA_OP_LIST[op]["op"]
6088 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08006089 raise Exception(
6090 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
6091 )
Eric Kunzee5e26762020-10-13 16:11:07 -07006092
6093 # Put in default rank range, if missing
6094 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006095 _ = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07006096 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08006097 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07006098
6099 # Tensor operator list
6100 # 'op': op name
6101 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08006102 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
6103 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07006104 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
6105 # 'types': array of datatypes to be tested
Kevin Cheng550ccc52021-03-03 11:21:43 -08006106 TYPE_FP = [DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07006107
Kevin Cheng550ccc52021-03-03 11:21:43 -08006108 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
6109 TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07006110
Kevin Cheng550ccc52021-03-03 11:21:43 -08006111 TYPE_BOOL = [DType.BOOL]
6112 TYPE_FI32 = [DType.FLOAT, DType.INT32]
6113 TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
6114 TYPE_FI16 = [DType.FLOAT, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07006115
Kevin Cheng550ccc52021-03-03 11:21:43 -08006116 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07006117
Kevin Cheng1533b852021-09-01 12:51:58 -07006118 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07006119 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07006120 [DType.INT8, DType.INT8, DType.INT32],
6121 [DType.INT16, DType.INT8, DType.INT48],
6122 DType.FLOAT,
6123 ]
6124
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01006125 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07006126
6127 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08006128 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08006129 "argmax": {
6130 "op": Op.ARGMAX,
6131 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01006132 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006133 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6134 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006135 "error_if_validators": (
6136 TosaErrorValidator.evAxisSmallerZero,
6137 TosaErrorValidator.evAxisLargerRank,
6138 TosaErrorValidator.evArgmaxOutputRankMismatch,
6139 TosaErrorValidator.evArgmaxOutputShapeMismatch,
6140 TosaErrorValidator.evWrongRank,
6141 TosaErrorValidator.evWrongInputType,
6142 TosaErrorValidator.evWrongOutputType,
6143 TosaErrorValidator.evWrongInputList,
6144 TosaErrorValidator.evWrongOutputList,
6145 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006146 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006147 "avg_pool2d": {
6148 "op": Op.AVG_POOL2D,
6149 "operands": (1, 0),
6150 "rank": (4, 4),
6151 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
6152 "qgen": TosaQuantGen.qgUnary,
6153 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00006154 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006155 "error_if_validators": (
6156 TosaErrorValidator.evKernelSmallerOne,
6157 TosaErrorValidator.evStrideSmallerOne,
6158 TosaErrorValidator.evPadSmallerZero,
6159 TosaErrorValidator.evWrongRank,
6160 TosaErrorValidator.evWrongInputType,
6161 TosaErrorValidator.evWrongOutputType,
6162 TosaErrorValidator.evWrongInputList,
6163 TosaErrorValidator.evWrongOutputList,
6164 TosaErrorValidator.evInputZeroPointNotZero,
6165 TosaErrorValidator.evOutputZeroPointNotZero,
6166 TosaErrorValidator.evPadLargerEqualKernel,
6167 TosaErrorValidator.evPoolingOutputShapeMismatch,
6168 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006169 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006170 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08006171 "conv2d_TEMPLATE": {
6172 "op": Op.CONV2D,
6173 "operands": (1, 2),
6174 "rank": (4, 4),
Les Bell7aa69f42021-09-20 10:44:07 +01006175 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006176 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07006177 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00006178 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
6179 "error_if_validators": (
6180 TosaErrorValidator.evWrongInputType,
6181 TosaErrorValidator.evWrongOutputType,
6182 TosaErrorValidator.evWrongInputList,
6183 TosaErrorValidator.evWrongOutputList,
6184 TosaErrorValidator.evInputZeroPointNotZero,
6185 TosaErrorValidator.evWeightZeroPointNotZero,
6186 TosaErrorValidator.evPadSmallerZero,
6187 TosaErrorValidator.evStrideSmallerOne,
6188 TosaErrorValidator.evDilationSmallerOne,
6189 TosaErrorValidator.evWrongRank,
6190 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006191 "template": True,
6192 },
Kevin Cheng1533b852021-09-01 12:51:58 -07006193 # Templated operator. Filled in by createDynamicOpLists
6194 "conv3d_TEMPLATE": {
6195 "op": Op.CONV3D,
6196 "operands": (1, 2),
6197 "rank": (5, 5),
Les Bell7aa69f42021-09-20 10:44:07 +01006198 "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
Kevin Cheng1533b852021-09-01 12:51:58 -07006199 "qgen": TosaQuantGen.qgConv,
6200 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00006201 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
6202 "error_if_validators": (
6203 TosaErrorValidator.evWrongInputType,
6204 TosaErrorValidator.evWrongOutputType,
6205 TosaErrorValidator.evWrongInputList,
6206 TosaErrorValidator.evWrongOutputList,
6207 TosaErrorValidator.evInputZeroPointNotZero,
6208 TosaErrorValidator.evWeightZeroPointNotZero,
6209 TosaErrorValidator.evPadSmallerZero,
6210 TosaErrorValidator.evStrideSmallerOne,
6211 TosaErrorValidator.evDilationSmallerOne,
6212 TosaErrorValidator.evWrongRank,
6213 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07006214 "template": True,
6215 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006216 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08006217 "depthwise_conv2d_TEMPLATE": {
6218 "op": Op.DEPTHWISE_CONV2D,
6219 "operands": (1, 2),
6220 "filter": [1, 1],
6221 "rank": (4, 4),
6222 "build_fcn": (
6223 build_depthwise_conv2d,
6224 TosaTensorGen.tgDepthwiseConv2D,
Les Bell7aa69f42021-09-20 10:44:07 +01006225 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08006226 ),
6227 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07006228 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00006229 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
6230 "error_if_validators": (
6231 TosaErrorValidator.evWrongInputType,
6232 TosaErrorValidator.evWrongOutputType,
6233 TosaErrorValidator.evWrongInputList,
6234 TosaErrorValidator.evWrongOutputList,
6235 TosaErrorValidator.evInputZeroPointNotZero,
6236 TosaErrorValidator.evWeightZeroPointNotZero,
6237 TosaErrorValidator.evPadSmallerZero,
6238 TosaErrorValidator.evStrideSmallerOne,
6239 TosaErrorValidator.evDilationSmallerOne,
6240 TosaErrorValidator.evWrongRank,
6241 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006242 "template": True,
6243 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006244 "fully_connected": {
6245 "op": Op.FULLY_CONNECTED,
6246 "operands": (1, 2),
6247 "rank": (2, 2),
6248 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
6249 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07006250 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006251 "error_if_validators": (
6252 TosaErrorValidator.evInputZeroPointNotZero,
6253 TosaErrorValidator.evWeightZeroPointNotZero,
6254 TosaErrorValidator.evWrongRank,
6255 TosaErrorValidator.evWrongInputType,
6256 TosaErrorValidator.evWrongOutputType,
6257 TosaErrorValidator.evWrongInputList,
6258 TosaErrorValidator.evWrongOutputList,
6259 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006260 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006261 "matmul": {
6262 "op": Op.MATMUL,
6263 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07006264 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08006265 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
6266 "qgen": TosaQuantGen.qgMatmul,
6267 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006268 "error_if_validators": (
6269 TosaErrorValidator.evInputZeroPointNotZero,
6270 TosaErrorValidator.evWrongRank,
6271 TosaErrorValidator.evWrongInputType,
6272 TosaErrorValidator.evWrongOutputType,
6273 TosaErrorValidator.evWrongInputList,
6274 TosaErrorValidator.evWrongOutputList,
6275 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006276 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006277 "max_pool2d": {
6278 "op": Op.MAX_POOL2D,
6279 "operands": (1, 0),
6280 "rank": (4, 4),
6281 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
6282 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00006283 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006284 "error_if_validators": (
6285 TosaErrorValidator.evKernelSmallerOne,
6286 TosaErrorValidator.evStrideSmallerOne,
6287 TosaErrorValidator.evPadSmallerZero,
6288 TosaErrorValidator.evWrongRank,
6289 TosaErrorValidator.evWrongInputType,
6290 TosaErrorValidator.evWrongOutputType,
6291 TosaErrorValidator.evWrongInputList,
6292 TosaErrorValidator.evWrongOutputList,
6293 TosaErrorValidator.evPadLargerEqualKernel,
6294 TosaErrorValidator.evPoolingOutputShapeMismatch,
6295 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006296 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006297 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08006298 "transpose_conv2d_TEMPLATE": {
6299 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07006300 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006301 "rank": (4, 4),
6302 "build_fcn": (
6303 build_transpose_conv2d,
6304 TosaTensorGen.tgTransposeConv2D,
6305 TosaArgGen.agTransposeConv2D,
6306 ),
6307 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07006308 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00006309 "invalid_test_validators": (
6310 TosaInvalidValidator.ivHeightWidthInvalid,
6311 TosaInvalidValidator.ivNonPositiveOutputShape,
6312 ),
6313 "error_if_validators": (
6314 TosaErrorValidator.evWrongInputType,
6315 TosaErrorValidator.evWrongOutputType,
6316 TosaErrorValidator.evWrongInputList,
6317 TosaErrorValidator.evWrongOutputList,
6318 TosaErrorValidator.evInputZeroPointNotZero,
6319 TosaErrorValidator.evWeightZeroPointNotZero,
6320 TosaErrorValidator.evPadSmallerZero,
6321 TosaErrorValidator.evStrideSmallerOne,
6322 TosaErrorValidator.evDilationSmallerOne,
6323 TosaErrorValidator.evWrongRank,
6324 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006325 "template": True,
6326 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006327 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08006328 "clamp": {
6329 "op": Op.CLAMP,
6330 "operands": (1, 0),
6331 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
6332 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006333 "error_if_validators": (
6334 TosaErrorValidator.evMaxSmallerMin,
6335 TosaErrorValidator.evWrongInputType,
6336 TosaErrorValidator.evWrongOutputType,
6337 TosaErrorValidator.evWrongInputList,
6338 TosaErrorValidator.evWrongOutputList,
6339 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006340 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08006341 "sigmoid": {
6342 "op": Op.SIGMOID,
6343 "operands": (1, 0),
6344 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
6345 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006346 "error_if_validators": (
6347 TosaErrorValidator.evWrongInputType,
6348 TosaErrorValidator.evWrongOutputType,
6349 TosaErrorValidator.evWrongInputList,
6350 TosaErrorValidator.evWrongOutputList,
6351 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006352 },
6353 "tanh": {
6354 "op": Op.TANH,
6355 "operands": (1, 0),
6356 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
6357 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006358 "error_if_validators": (
6359 TosaErrorValidator.evWrongInputType,
6360 TosaErrorValidator.evWrongOutputType,
6361 TosaErrorValidator.evWrongInputList,
6362 TosaErrorValidator.evWrongOutputList,
6363 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006364 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006365 # Elementwise Binary Operators
6366 "add": {
6367 "op": Op.ADD,
6368 "operands": (2, 0),
6369 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6370 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006371 "error_if_validators": (
6372 TosaErrorValidator.evRankMismatch,
6373 TosaErrorValidator.evWrongInputType,
6374 TosaErrorValidator.evWrongOutputType,
6375 TosaErrorValidator.evWrongInputList,
6376 TosaErrorValidator.evWrongOutputList,
6377 TosaErrorValidator.evDimensionMismatch,
6378 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006379 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006380 "arithmetic_right_shift": {
6381 "op": Op.ARITHMETIC_RIGHT_SHIFT,
6382 "operands": (2, 0),
6383 "build_fcn": (
6384 build_arithmetic_right_shift,
6385 TosaTensorGen.tgBroadcastFuzz,
6386 TosaArgGen.agArithmeticRightShift,
6387 ),
6388 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006389 "error_if_validators": (
6390 TosaErrorValidator.evRankMismatch,
6391 TosaErrorValidator.evWrongInputType,
6392 TosaErrorValidator.evWrongOutputType,
6393 TosaErrorValidator.evWrongInputList,
6394 TosaErrorValidator.evWrongOutputList,
6395 TosaErrorValidator.evDimensionMismatch,
6396 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006397 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006398 "bitwise_and": {
6399 "op": Op.BITWISE_AND,
6400 "operands": (2, 0),
6401 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6402 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006403 "error_if_validators": (
6404 TosaErrorValidator.evRankMismatch,
6405 TosaErrorValidator.evWrongInputType,
6406 TosaErrorValidator.evWrongOutputType,
6407 TosaErrorValidator.evWrongInputList,
6408 TosaErrorValidator.evWrongOutputList,
6409 TosaErrorValidator.evDimensionMismatch,
6410 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006411 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006412 "bitwise_or": {
6413 "op": Op.BITWISE_OR,
6414 "operands": (2, 0),
6415 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6416 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006417 "error_if_validators": (
6418 TosaErrorValidator.evRankMismatch,
6419 TosaErrorValidator.evWrongInputType,
6420 TosaErrorValidator.evWrongOutputType,
6421 TosaErrorValidator.evWrongInputList,
6422 TosaErrorValidator.evWrongOutputList,
6423 TosaErrorValidator.evDimensionMismatch,
6424 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006425 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006426 "bitwise_xor": {
6427 "op": Op.BITWISE_XOR,
6428 "operands": (2, 0),
6429 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6430 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006431 "error_if_validators": (
6432 TosaErrorValidator.evRankMismatch,
6433 TosaErrorValidator.evWrongInputType,
6434 TosaErrorValidator.evWrongOutputType,
6435 TosaErrorValidator.evWrongInputList,
6436 TosaErrorValidator.evWrongOutputList,
6437 TosaErrorValidator.evDimensionMismatch,
6438 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006439 },
Matthew Haddon459443c2021-08-23 16:43:13 +01006440 "intdiv": {
6441 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07006442 "operands": (2, 0),
6443 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6444 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006445 "error_if_validators": (
6446 TosaErrorValidator.evRankMismatch,
6447 TosaErrorValidator.evWrongInputType,
6448 TosaErrorValidator.evWrongOutputType,
6449 TosaErrorValidator.evWrongInputList,
6450 TosaErrorValidator.evWrongOutputList,
6451 TosaErrorValidator.evDimensionMismatch,
6452 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07006453 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006454 "logical_and": {
6455 "op": Op.LOGICAL_AND,
6456 "operands": (2, 0),
6457 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6458 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006459 "error_if_validators": (
6460 TosaErrorValidator.evRankMismatch,
6461 TosaErrorValidator.evWrongInputType,
6462 TosaErrorValidator.evWrongOutputType,
6463 TosaErrorValidator.evWrongInputList,
6464 TosaErrorValidator.evWrongOutputList,
6465 TosaErrorValidator.evDimensionMismatch,
6466 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006467 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006468 "logical_left_shift": {
6469 "op": Op.LOGICAL_LEFT_SHIFT,
6470 "operands": (2, 0),
6471 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6472 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006473 "error_if_validators": (
6474 TosaErrorValidator.evRankMismatch,
6475 TosaErrorValidator.evWrongInputType,
6476 TosaErrorValidator.evWrongOutputType,
6477 TosaErrorValidator.evWrongInputList,
6478 TosaErrorValidator.evWrongOutputList,
6479 TosaErrorValidator.evDimensionMismatch,
6480 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006481 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006482 "logical_right_shift": {
6483 "op": Op.LOGICAL_RIGHT_SHIFT,
6484 "operands": (2, 0),
6485 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6486 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006487 "error_if_validators": (
6488 TosaErrorValidator.evRankMismatch,
6489 TosaErrorValidator.evWrongInputType,
6490 TosaErrorValidator.evWrongOutputType,
6491 TosaErrorValidator.evWrongInputList,
6492 TosaErrorValidator.evWrongOutputList,
6493 TosaErrorValidator.evDimensionMismatch,
6494 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006495 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006496 "logical_or": {
6497 "op": Op.LOGICAL_OR,
6498 "operands": (2, 0),
6499 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6500 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006501 "error_if_validators": (
6502 TosaErrorValidator.evRankMismatch,
6503 TosaErrorValidator.evWrongInputType,
6504 TosaErrorValidator.evWrongOutputType,
6505 TosaErrorValidator.evWrongInputList,
6506 TosaErrorValidator.evWrongOutputList,
6507 TosaErrorValidator.evDimensionMismatch,
6508 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006509 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006510 "logical_xor": {
6511 "op": Op.LOGICAL_XOR,
6512 "operands": (2, 0),
6513 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6514 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006515 "error_if_validators": (
6516 TosaErrorValidator.evRankMismatch,
6517 TosaErrorValidator.evWrongInputType,
6518 TosaErrorValidator.evWrongOutputType,
6519 TosaErrorValidator.evWrongInputList,
6520 TosaErrorValidator.evWrongOutputList,
6521 TosaErrorValidator.evDimensionMismatch,
6522 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006523 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006524 "maximum": {
6525 "op": Op.MAXIMUM,
6526 "operands": (2, 0),
6527 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6528 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006529 "error_if_validators": (
6530 TosaErrorValidator.evRankMismatch,
6531 TosaErrorValidator.evWrongInputType,
6532 TosaErrorValidator.evWrongOutputType,
6533 TosaErrorValidator.evWrongInputList,
6534 TosaErrorValidator.evWrongOutputList,
6535 TosaErrorValidator.evDimensionMismatch,
6536 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006537 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006538 "minimum": {
6539 "op": Op.MINIMUM,
6540 "operands": (2, 0),
6541 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6542 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006543 "error_if_validators": (
6544 TosaErrorValidator.evRankMismatch,
6545 TosaErrorValidator.evWrongInputType,
6546 TosaErrorValidator.evWrongOutputType,
6547 TosaErrorValidator.evWrongInputList,
6548 TosaErrorValidator.evWrongOutputList,
6549 TosaErrorValidator.evDimensionMismatch,
6550 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006551 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006552 "mul": {
6553 "op": Op.MUL,
6554 "operands": (2, 0),
6555 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
6556 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006557 "error_if_validators": (
6558 TosaErrorValidator.evWrongInputType,
6559 TosaErrorValidator.evWrongOutputType,
6560 TosaErrorValidator.evWrongInputList,
6561 TosaErrorValidator.evWrongOutputList,
6562 TosaErrorValidator.evRankMismatch,
6563 TosaErrorValidator.evDimensionMismatch,
6564 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006565 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006566 "pow": {
6567 "op": Op.POW,
6568 "operands": (2, 0),
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00006569 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08006570 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006571 "error_if_validators": (
6572 TosaErrorValidator.evRankMismatch,
6573 TosaErrorValidator.evWrongInputType,
6574 TosaErrorValidator.evWrongOutputType,
6575 TosaErrorValidator.evWrongInputList,
6576 TosaErrorValidator.evWrongOutputList,
6577 TosaErrorValidator.evDimensionMismatch,
6578 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006579 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006580 "sub": {
6581 "op": Op.SUB,
6582 "operands": (2, 0),
6583 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
6584 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006585 "error_if_validators": (
6586 TosaErrorValidator.evRankMismatch,
6587 TosaErrorValidator.evWrongInputType,
6588 TosaErrorValidator.evWrongOutputType,
6589 TosaErrorValidator.evWrongInputList,
6590 TosaErrorValidator.evWrongOutputList,
6591 TosaErrorValidator.evDimensionMismatch,
6592 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006593 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006594 "table": {
6595 "op": Op.TABLE,
6596 # Use the automatic generation functions to create the input array
6597 # but create the table tensor in the build function, as it may be
6598 # a different type from the input
6599 "operands": (1, 0),
Kevin Chengfe392ce2021-10-18 21:51:55 +00006600 "build_fcn": (build_table, TosaTensorGen.tgBasic, TosaArgGen.agTable),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01006601 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006602 "error_if_validators": (
6603 TosaErrorValidator.evWrongInputType,
6604 TosaErrorValidator.evWrongOutputType,
6605 TosaErrorValidator.evWrongInputList,
6606 TosaErrorValidator.evWrongOutputList,
6607 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006608 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006609 # Elementwise Unary operators
6610 "abs": {
6611 "op": Op.ABS,
6612 "operands": (1, 0),
6613 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6614 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006615 "error_if_validators": (
6616 TosaErrorValidator.evWrongInputType,
6617 TosaErrorValidator.evWrongOutputType,
6618 TosaErrorValidator.evWrongInputList,
6619 TosaErrorValidator.evWrongOutputList,
6620 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006621 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006622 "bitwise_not": {
6623 "op": Op.BITWISE_NOT,
6624 "operands": (1, 0),
6625 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6626 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006627 "error_if_validators": (
6628 TosaErrorValidator.evWrongInputType,
6629 TosaErrorValidator.evWrongOutputType,
6630 TosaErrorValidator.evWrongInputList,
6631 TosaErrorValidator.evWrongOutputList,
6632 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006633 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006634 "ceil": {
6635 "op": Op.CEIL,
6636 "operands": (1, 0),
6637 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6638 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006639 "error_if_validators": (
6640 TosaErrorValidator.evWrongInputType,
6641 TosaErrorValidator.evWrongOutputType,
6642 TosaErrorValidator.evWrongInputList,
6643 TosaErrorValidator.evWrongOutputList,
6644 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006645 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006646 "clz": {
6647 "op": Op.CLZ,
6648 "operands": (1, 0),
6649 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6650 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006651 "error_if_validators": (
6652 TosaErrorValidator.evWrongInputType,
6653 TosaErrorValidator.evWrongOutputType,
6654 TosaErrorValidator.evWrongInputList,
6655 TosaErrorValidator.evWrongOutputList,
6656 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006657 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006658 "exp": {
6659 "op": Op.EXP,
6660 "operands": (1, 0),
6661 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6662 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006663 "error_if_validators": (
6664 TosaErrorValidator.evWrongInputType,
6665 TosaErrorValidator.evWrongOutputType,
6666 TosaErrorValidator.evWrongInputList,
6667 TosaErrorValidator.evWrongOutputList,
6668 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006669 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006670 "floor": {
6671 "op": Op.FLOOR,
6672 "operands": (1, 0),
6673 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6674 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006675 "error_if_validators": (
6676 TosaErrorValidator.evWrongInputType,
6677 TosaErrorValidator.evWrongOutputType,
6678 TosaErrorValidator.evWrongInputList,
6679 TosaErrorValidator.evWrongOutputList,
6680 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006681 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006682 "log": {
6683 "op": Op.LOG,
6684 "operands": (1, 0),
6685 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6686 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006687 "error_if_validators": (
6688 TosaErrorValidator.evWrongInputType,
6689 TosaErrorValidator.evWrongOutputType,
6690 TosaErrorValidator.evWrongInputList,
6691 TosaErrorValidator.evWrongOutputList,
6692 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006693 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006694 "logical_not": {
6695 "op": Op.LOGICAL_NOT,
6696 "operands": (1, 0),
6697 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6698 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006699 "error_if_validators": (
6700 TosaErrorValidator.evWrongInputType,
6701 TosaErrorValidator.evWrongOutputType,
6702 TosaErrorValidator.evWrongInputList,
6703 TosaErrorValidator.evWrongOutputList,
6704 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006705 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006706 "negate": {
6707 "op": Op.NEGATE,
6708 "operands": (1, 0),
6709 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6710 "qgen": TosaQuantGen.qgUnary,
6711 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006712 "error_if_validators": (
6713 TosaErrorValidator.evInputZeroPointNotZero,
6714 TosaErrorValidator.evOutputZeroPointNotZero,
6715 TosaErrorValidator.evWrongInputType,
6716 TosaErrorValidator.evWrongOutputType,
6717 TosaErrorValidator.evWrongInputList,
6718 TosaErrorValidator.evWrongOutputList,
6719 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006720 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006721 "reciprocal": {
6722 "op": Op.RECIPROCAL,
6723 "operands": (1, 0),
6724 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6725 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006726 "error_if_validators": (
6727 TosaErrorValidator.evWrongInputType,
6728 TosaErrorValidator.evWrongOutputType,
6729 TosaErrorValidator.evWrongInputList,
6730 TosaErrorValidator.evWrongOutputList,
6731 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006732 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006733 "rsqrt": {
6734 "op": Op.RSQRT,
6735 "operands": (1, 0),
6736 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
6737 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006738 "error_if_validators": (
6739 TosaErrorValidator.evWrongInputType,
6740 TosaErrorValidator.evWrongOutputType,
6741 TosaErrorValidator.evWrongInputList,
6742 TosaErrorValidator.evWrongOutputList,
6743 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006744 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006745 # Elementwise Ternary operators
6746 "select": {
6747 "op": Op.SELECT,
6748 "operands": (3, 0),
6749 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
6750 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006751 "error_if_validators": (
6752 TosaErrorValidator.evRankMismatch,
6753 TosaErrorValidator.evWrongInputType,
6754 TosaErrorValidator.evWrongOutputType,
6755 TosaErrorValidator.evWrongInputList,
6756 TosaErrorValidator.evWrongOutputList,
6757 TosaErrorValidator.evDimensionMismatch,
6758 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006759 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006760 # Comparison operators
6761 "equal": {
6762 "op": Op.EQUAL,
6763 "operands": (2, 0),
6764 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
6765 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006766 "error_if_validators": (
6767 TosaErrorValidator.evRankMismatch,
6768 TosaErrorValidator.evWrongInputType,
6769 TosaErrorValidator.evWrongOutputType,
6770 TosaErrorValidator.evWrongInputList,
6771 TosaErrorValidator.evWrongOutputList,
6772 TosaErrorValidator.evDimensionMismatch,
6773 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006774 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006775 "greater_equal": {
6776 "op": Op.GREATER_EQUAL,
6777 "operands": (2, 0),
6778 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
6779 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006780 "error_if_validators": (
6781 TosaErrorValidator.evRankMismatch,
6782 TosaErrorValidator.evWrongInputType,
6783 TosaErrorValidator.evWrongOutputType,
6784 TosaErrorValidator.evWrongInputList,
6785 TosaErrorValidator.evWrongOutputList,
6786 TosaErrorValidator.evDimensionMismatch,
6787 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006788 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006789 "greater": {
6790 "op": Op.GREATER,
6791 "operands": (2, 0),
6792 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
6793 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006794 "error_if_validators": (
6795 TosaErrorValidator.evRankMismatch,
6796 TosaErrorValidator.evWrongInputType,
6797 TosaErrorValidator.evWrongOutputType,
6798 TosaErrorValidator.evWrongInputList,
6799 TosaErrorValidator.evWrongOutputList,
6800 TosaErrorValidator.evDimensionMismatch,
6801 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006802 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006803 # Reduction operators
6804 "reduce_all": {
6805 "op": Op.REDUCE_ALL,
6806 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006807 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006808 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6809 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006810 "error_if_validators": (
6811 TosaErrorValidator.evAxisLargerRank,
6812 TosaErrorValidator.evAxisSmallerZero,
6813 TosaErrorValidator.evShapeOfAxisNotOne,
6814 TosaErrorValidator.evWrongInputType,
6815 TosaErrorValidator.evWrongOutputType,
6816 TosaErrorValidator.evWrongRank,
6817 TosaErrorValidator.evWrongInputList,
6818 TosaErrorValidator.evWrongOutputList,
6819 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006820 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006821 "reduce_any": {
6822 "op": Op.REDUCE_ANY,
6823 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006824 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006825 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6826 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006827 "error_if_validators": (
6828 TosaErrorValidator.evAxisLargerRank,
6829 TosaErrorValidator.evAxisSmallerZero,
6830 TosaErrorValidator.evShapeOfAxisNotOne,
6831 TosaErrorValidator.evWrongInputType,
6832 TosaErrorValidator.evWrongOutputType,
6833 TosaErrorValidator.evWrongRank,
6834 TosaErrorValidator.evWrongInputList,
6835 TosaErrorValidator.evWrongOutputList,
6836 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006837 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006838 "reduce_max": {
6839 "op": Op.REDUCE_MAX,
6840 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006841 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006842 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6843 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006844 "error_if_validators": (
6845 TosaErrorValidator.evAxisLargerRank,
6846 TosaErrorValidator.evAxisSmallerZero,
6847 TosaErrorValidator.evShapeOfAxisNotOne,
6848 TosaErrorValidator.evWrongInputType,
6849 TosaErrorValidator.evWrongOutputType,
6850 TosaErrorValidator.evWrongRank,
6851 TosaErrorValidator.evWrongInputList,
6852 TosaErrorValidator.evWrongOutputList,
6853 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006854 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006855 "reduce_min": {
6856         "op": Op.REDUCE_MIN,
6857 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006858 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006859 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6860 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006861 "error_if_validators": (
6862 TosaErrorValidator.evAxisLargerRank,
6863 TosaErrorValidator.evAxisSmallerZero,
6864 TosaErrorValidator.evShapeOfAxisNotOne,
6865 TosaErrorValidator.evWrongInputType,
6866 TosaErrorValidator.evWrongOutputType,
6867 TosaErrorValidator.evWrongRank,
6868 TosaErrorValidator.evWrongInputList,
6869 TosaErrorValidator.evWrongOutputList,
6870 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006871 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006872 "reduce_product": {
6873 "op": Op.REDUCE_PRODUCT,
6874 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006875 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006876 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6877 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006878 "error_if_validators": (
6879 TosaErrorValidator.evAxisLargerRank,
6880 TosaErrorValidator.evAxisSmallerZero,
6881 TosaErrorValidator.evShapeOfAxisNotOne,
6882 TosaErrorValidator.evWrongInputType,
6883 TosaErrorValidator.evWrongOutputType,
6884 TosaErrorValidator.evWrongRank,
6885 TosaErrorValidator.evWrongInputList,
6886 TosaErrorValidator.evWrongOutputList,
6887 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006888 },
Jared Smolens573ecd42021-03-04 15:24:10 -08006889 "reduce_sum": {
6890 "op": Op.REDUCE_SUM,
6891 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00006892 "rank": (1, 4),
Jared Smolens573ecd42021-03-04 15:24:10 -08006893 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6894 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006895 "error_if_validators": (
6896 TosaErrorValidator.evAxisLargerRank,
6897 TosaErrorValidator.evAxisSmallerZero,
6898 TosaErrorValidator.evShapeOfAxisNotOne,
6899 TosaErrorValidator.evWrongInputType,
6900 TosaErrorValidator.evWrongOutputType,
6901 TosaErrorValidator.evWrongRank,
6902 TosaErrorValidator.evWrongInputList,
6903 TosaErrorValidator.evWrongOutputList,
6904 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08006905 },
Eric Kunzee5e26762020-10-13 16:11:07 -07006906 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08006907 "concat": {
6908 "op": Op.CONCAT,
6909 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01006910 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006911 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006912 "error_if_validators": (
6913 TosaErrorValidator.evAxisLargerRank,
6914 TosaErrorValidator.evAxisSmallerZero,
6915 TosaErrorValidator.evConcatInputRankMismatch,
6916 TosaErrorValidator.evConcatShapeSumMismatch,
6917 TosaErrorValidator.evConcatInputDimMismatch,
6918 TosaErrorValidator.evWrongInputType,
6919 TosaErrorValidator.evWrongOutputType,
6920 TosaErrorValidator.evWrongOutputList,
6921 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006922 },
6923 "pad": {
6924 "op": Op.PAD,
6925 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01006926 "rank": (1, 5),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006927 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
6928 "qgen": TosaQuantGen.qgPad,
6929 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006930 "error_if_validators": (
6931 TosaErrorValidator.evWrongInputType,
6932 TosaErrorValidator.evPadSmallerZero,
6933 TosaErrorValidator.evWrongOutputType,
6934 TosaErrorValidator.evWrongInputList,
6935 TosaErrorValidator.evWrongOutputList,
6936 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006937 },
6938 "reshape": {
6939 "op": Op.RESHAPE,
6940 "operands": (1, 0),
6941 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
6942 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006943 "error_if_validators": (
6944 TosaErrorValidator.evTensorSizeInputOutputMismatch,
6945 TosaErrorValidator.evWrongInputType,
6946 TosaErrorValidator.evWrongOutputType,
6947 TosaErrorValidator.evWrongInputList,
6948 TosaErrorValidator.evWrongOutputList,
6949 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006950 },
6951 "reverse": {
6952 "op": Op.REVERSE,
6953 "operands": (1, 0),
6954 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
6955 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006956 "error_if_validators": (
6957 TosaErrorValidator.evAxisSmallerZero,
6958 TosaErrorValidator.evAxisLargerRank,
6959 TosaErrorValidator.evWrongInputType,
6960 TosaErrorValidator.evWrongOutputType,
6961 TosaErrorValidator.evWrongInputList,
6962 TosaErrorValidator.evWrongOutputList,
6963 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006964 },
6965 "slice": {
6966 "op": Op.SLICE,
6967 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01006968 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006969 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
6970 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006971 "error_if_validators": (
6972 TosaErrorValidator.evStartSmallerZero,
6973 TosaErrorValidator.evSizeSmallerEqualZero,
6974 TosaErrorValidator.evStartSizeOutsideBounds,
6975 TosaErrorValidator.evSizeOutputShapeMismatch,
6976 TosaErrorValidator.evInputSizeStartLengthMismatch,
6977 TosaErrorValidator.evWrongRank,
6978 TosaErrorValidator.evWrongInputType,
6979 TosaErrorValidator.evWrongOutputType,
6980 TosaErrorValidator.evWrongInputList,
6981 TosaErrorValidator.evWrongOutputList,
6982 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006983 },
6984 "tile": {
6985 "op": Op.TILE,
6986 "operands": (1, 0),
6987 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
6988 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006989 "error_if_validators": (
6990 TosaErrorValidator.evWrongInputType,
6991 TosaErrorValidator.evWrongOutputType,
6992 TosaErrorValidator.evWrongInputList,
6993 TosaErrorValidator.evWrongOutputList,
6994 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08006995 },
6996 "transpose": {
6997 "op": Op.TRANSPOSE,
6998 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01006999 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007000 "build_fcn": (
7001 build_transpose,
7002 TosaTensorGen.tgBasic,
7003 TosaArgGen.agTranspose,
7004 ),
7005 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007006 "error_if_validators": (
7007 TosaErrorValidator.evIndexOutsideBounds,
7008 TosaErrorValidator.evIndexUsedTwice,
7009 TosaErrorValidator.evWrongInputType,
7010 TosaErrorValidator.evWrongOutputType,
7011 TosaErrorValidator.evWrongInputList,
7012 TosaErrorValidator.evWrongOutputList,
7013 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007014 },
Jared Smolens573ecd42021-03-04 15:24:10 -08007015 # Data nodes
7016 "const": {
7017 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07007018 "operands": (0, 1),
7019 "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08007020 "types": TYPE_FIB,
7021 },
Jared Smolens573ecd42021-03-04 15:24:10 -08007022 "identity": {
7023 "op": Op.IDENTITY,
7024 "operands": (1, 0),
7025 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
7026 "types": TYPE_FIB,
7027 },
Eric Kunzee5e26762020-10-13 16:11:07 -07007028 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08007029 "gather": {
7030 "op": Op.GATHER,
7031 # Only specify 'values' tensor here. 'indices' is generated in op building stage
7032 "operands": (1, 0),
7033 "rank": (3, 3),
7034 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
7035 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007036 "error_if_validators": (
7037 TosaErrorValidator.evWrongInputType,
7038 TosaErrorValidator.evWrongOutputType,
7039 TosaErrorValidator.evWrongInputList,
7040 TosaErrorValidator.evWrongOutputList,
7041 TosaErrorValidator.evWrongRank,
7042 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007043 },
7044 "scatter": {
7045 "op": Op.SCATTER,
7046 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007047 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08007048 "operands": (2, 0),
7049 "rank": (3, 3),
7050 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
7051 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007052 "error_if_validators": (
7053 TosaErrorValidator.evWrongInputType,
7054 TosaErrorValidator.evWrongOutputType,
7055 TosaErrorValidator.evWrongInputList,
7056 TosaErrorValidator.evWrongOutputList,
7057 TosaErrorValidator.evWrongRank,
7058 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007059 },
Eric Kunzee5e26762020-10-13 16:11:07 -07007060 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08007061 "resize": {
7062 "op": Op.RESIZE,
7063 "operands": (1, 0),
7064 "rank": (4, 4),
7065 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
7066 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007067 "invalid_test_validators": (
7068 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
7069 TosaInvalidValidator.ivBadStride,
7070 ),
7071 "error_if_validators": (
7072 TosaErrorValidator.evMaxDimExceeded,
7073 TosaErrorValidator.evStrideSmallerEqualZero,
7074 TosaErrorValidator.evStrideLargerDimension,
7075 TosaErrorValidator.evStrideLargerEqualMax,
7076 TosaErrorValidator.evOffsetSmallerEqualMin,
7077 TosaErrorValidator.evOffsetLargerEqualMax,
7078 TosaErrorValidator.evShiftNotZero,
7079 TosaErrorValidator.evShiftSmallerOne,
7080 TosaErrorValidator.evShiftLargerEleven,
7081 TosaErrorValidator.evWrongInputType,
7082 TosaErrorValidator.evWrongOutputType,
7083 TosaErrorValidator.evWrongRank,
7084 TosaErrorValidator.evWrongInputList,
7085 TosaErrorValidator.evWrongOutputList,
7086 TosaErrorValidator.evBatchMismatch,
7087 TosaErrorValidator.evChannelMismatch,
7088 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007089 },
Eric Kunzee5e26762020-10-13 16:11:07 -07007090 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08007091 "cast": {
7092 "op": Op.CAST,
7093 "operands": (1, 0),
7094 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
7095 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007096 "error_if_validators": (
7097 TosaErrorValidator.evWrongInputType,
7098 TosaErrorValidator.evWrongOutputType,
7099 TosaErrorValidator.evWrongInputList,
7100 TosaErrorValidator.evWrongOutputList,
7101 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007102 },
7103 "rescale": {
7104 "op": Op.RESCALE,
7105 "operands": (1, 0),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007106 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007107 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01007108 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007109 "error_if_validators": (
7110 TosaErrorValidator.evInputZeroPointNotZero,
7111 TosaErrorValidator.evOutputZeroPointNotZero,
7112 TosaErrorValidator.evScaleTrue,
7113 TosaErrorValidator.evScaleNotTrue,
7114 TosaErrorValidator.evWrongInputType,
7115 TosaErrorValidator.evWrongOutputType,
7116 TosaErrorValidator.evWrongRank,
7117 TosaErrorValidator.evWrongInputList,
7118 TosaErrorValidator.evWrongOutputList,
7119 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007120 },
Eric Kunzee5e26762020-10-13 16:11:07 -07007121 # Custom
7122 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08007123 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07007124    # Two variants of cond_if, one that generates one of two constant tensors (no
7125 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
7126 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08007127 "cond_if_const": {
7128 "op": Op.COND_IF,
7129 "operands": (0, 2),
7130 "build_fcn": (
7131 build_cond_if_const,
7132 TosaTensorGen.tgBasic,
7133 TosaArgGen.agCondIf,
7134 ),
7135 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007136 "error_if_validators": (
7137 TosaErrorValidator.evOutputListThenGraphMismatch,
7138 TosaErrorValidator.evOutputListElseGraphMismatch,
7139 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007140 },
7141 "cond_if_binary": {
7142 "op": Op.COND_IF,
7143 "operands": (2, 0),
7144 "build_fcn": (
7145 build_cond_if_binary,
7146 TosaTensorGen.tgBasic,
7147 TosaArgGen.agCondIf,
7148 ),
Les Bell6040b4d2021-10-11 12:50:31 +01007149 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007150 "error_if_validators": (
7151 TosaErrorValidator.evInputListThenGraphMismatch,
7152 TosaErrorValidator.evInputListElseGraphMismatch,
7153 TosaErrorValidator.evOutputListThenGraphMismatch,
7154 TosaErrorValidator.evOutputListElseGraphMismatch,
7155 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007156 },
Eric Kunzee5e26762020-10-13 16:11:07 -07007157 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08007158 "while_loop": {
7159 "op": Op.WHILE_LOOP,
7160 "operands": (0, 1),
7161 "build_fcn": (
7162 build_while_loop,
7163 TosaTensorGen.tgBasic,
7164 TosaArgGen.agWhileLoop,
7165 ),
7166 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007167 "error_if_validators": (
7168 TosaErrorValidator.evInputListOutputListMismatch,
7169 TosaErrorValidator.evInputListCondGraphMismatch,
7170 TosaErrorValidator.evInputListBodyGraphInputMismatch,
7171 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
7172 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
7173 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08007174 },
Eric Kunzee5e26762020-10-13 16:11:07 -07007175 }
7176
Kevin Cheng550ccc52021-03-03 11:21:43 -08007177
Eric Kunzee5e26762020-10-13 16:11:07 -07007178class OutputShaper:
7179 # Methods in this class compute the expected output shape and datatype
7180 # for common classes of operations
7181 def __init__(self):
7182 pass
7183
7184 # These methods return arguments that can be used for
7185 # creating a new output tensor
7186 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01007187 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
7188 if error_name != ErrorIf.RankMismatch:
7189 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08007190 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07007191
7192 shape = []
7193 for i in range(len(a.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007194 if a.shape[i] == 1 and error_name is None:
Eric Kunzee5e26762020-10-13 16:11:07 -07007195 shape.append(b.shape[i])
7196 else:
7197 shape.append(a.shape[i])
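        # For illustration (hypothetical values): broadcasting a.shape = [1, 3, 1]
        # against b.shape = [2, 3, 4] yields shape = [2, 3, 4], since each size-1
        # dimension of 'a' takes the corresponding size from 'b'.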
7198
Matthew Haddoneacff9a2021-09-24 14:42:13 +01007199 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007200 all_dtypes = [
7201 DType.INT8,
7202 DType.INT16,
7203 DType.INT32,
7204 DType.INT48,
7205 DType.FLOAT,
7206 ]
Matthew Haddoneacff9a2021-09-24 14:42:13 +01007207 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
7208 outputDType = rng.choice(wrong_dtypes)
7209 else:
7210 outputDType = a.dtype
7211
7212 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07007213
7214 @staticmethod
7215 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08007216 assert len(a.shape) == len(b.shape)
7217 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07007218
7219 shape = []
7220 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08007221 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07007222 shape.append(a.shape[i])
7223
Kevin Cheng550ccc52021-03-03 11:21:43 -08007224 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07007225
7226 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01007227 def unaryOp(ser, rng, a, error_name=None):
7228 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007229 all_dtypes = [
7230 DType.INT8,
7231 DType.INT16,
7232 DType.INT32,
7233 DType.INT48,
7234 DType.FLOAT,
7235 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01007236 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
7237 outputDType = rng.choice(wrong_dtypes)
7238 else:
7239 outputDType = a.dtype
7240
7241 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07007242
7243 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01007244 def selectOp(ser, rng, cond, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00007245 if error_name != ErrorIf.RankMismatch:
7246 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08007247 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07007248
7249 shape = []
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00007250 for i in range(len(cond.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007251 if cond.shape[i] == 1 and error_name is None:
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00007252 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
7253 else:
7254 shape.append(cond.shape[i])
Eric Kunzee5e26762020-10-13 16:11:07 -07007255
Matthew Haddonbb5676f2021-10-13 11:30:30 +01007256 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007257 all_dtypes = [
7258 DType.INT8,
7259 DType.INT16,
7260 DType.INT32,
7261 DType.INT48,
7262 DType.FLOAT,
7263 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01007264 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
7265 outputDType = rng.choice(wrong_dtypes)
7266 else:
7267 outputDType = a.dtype
7268
7269 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07007270
7271 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007272 def binaryComparisonOp(ser, rng, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00007273 if error_name != ErrorIf.RankMismatch:
7274 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08007275 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07007276
7277 # Do broadcast
7278 shape = []
7279 for i in range(len(a.shape)):
Eric Kunzea1d49852022-01-04 10:07:29 -08007280 if a.shape[i] == 1 and len(b.shape) > i:
Eric Kunzee5e26762020-10-13 16:11:07 -07007281 shape.append(b.shape[i])
7282 else:
7283 shape.append(a.shape[i])
7284
Matthew Haddonbb5676f2021-10-13 11:30:30 +01007285 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007286 wrong_dtypes = [
7287 DType.INT8,
7288 DType.INT16,
7289 DType.INT32,
7290 DType.INT48,
7291 DType.FLOAT,
7292 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01007293 outputDType = rng.choice(wrong_dtypes)
7294 else:
7295 outputDType = DType.BOOL
7296
7297 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07007298
7299 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01007300 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07007301 shape = a.shape.copy()
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007302 if error_name not in [
7303 ErrorIf.AxisSmallerZero,
7304 ErrorIf.AxisLargerRank,
7305 ErrorIf.ShapeOfAxisNotOne,
7306 ]:
Matthew Haddond6ce7252021-09-29 15:35:44 +01007307 shape[axis] = 1
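            # For illustration (hypothetical values): reducing a [2, 3, 4] tensor
            # along axis=1 keeps the rank and gives an output shape of [2, 1, 4].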
7308 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
7309 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07007310
Matthew Haddond6ce7252021-09-29 15:35:44 +01007311 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007312 all_dtypes = [
7313 DType.INT8,
7314 DType.INT16,
7315 DType.INT32,
7316 DType.INT48,
7317 DType.FLOAT,
7318 ]
Matthew Haddond6ce7252021-09-29 15:35:44 +01007319 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
7320 outputDType = rng.choice(wrong_dtypes)
7321 else:
7322 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07007323
Matthew Haddond6ce7252021-09-29 15:35:44 +01007324 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07007325
7326 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007327 def argmaxOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07007328 shape = a.shape.copy()
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007329
7330 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
7331 del shape[axis]
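            # For illustration (hypothetical values): ARGMAX over a [2, 3, 4] input
            # along axis=1 removes that dimension, giving shape [2, 4] with an
            # INT32 output in the non-error case.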
7332
7333 if error_name == ErrorIf.ArgmaxOutputRankMismatch:
7334 remove = rng.choice([True, False])
7335 if remove and len(shape) > 1:
7336 del shape[0]
7337 else:
7338 shape.append(1)
7339 elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
7340 for i in range(len(shape)):
7341 shape[i] = shape[i] + rng.integers(1, 10)
7342
7343 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007344 all_dtypes = [
7345 DType.INT8,
7346 DType.INT16,
7347 DType.INT32,
7348 DType.INT48,
7349 DType.FLOAT,
7350 ]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007351 wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
7352 outputDType = rng.choice(wrong_dtypes)
7353 else:
7354 outputDType = DType.INT32
7355
7356 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07007357
7358 @staticmethod
Les Bell0e027d42021-11-09 14:42:14 +00007359 def conv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07007360
7361 # IFM: NHWC
7362 # Filter: OHWI
7363 # OFM: NHWC
7364
7365 if len(padding) == 2:
7366 # Expand padding to 4 parameters in the case of transpose_conv2d
7367 # From H,W to T,B,L,R
7368 padding = [padding[0], padding[0], padding[1], padding[1]]
7369
Kevin Cheng550ccc52021-03-03 11:21:43 -08007370 h = (
7371 ifm.shape[1]
7372 - filter.shape[1]
7373 - (filter.shape[1] - 1) * (dilations[0] - 1)
7374 + padding[0]
7375 + padding[1]
7376 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07007377
Kevin Cheng550ccc52021-03-03 11:21:43 -08007378 w = (
7379 ifm.shape[2]
7380 - filter.shape[2]
7381 - (filter.shape[2] - 1) * (dilations[1] - 1)
7382 + padding[2]
7383 + padding[3]
7384 ) // strides[1] + 1
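        # Worked example (hypothetical values): with an NHWC input of 1x14x14x8,
        # an OHWI filter of 8x3x3x8, strides (1, 1), dilations (2, 2) and
        # padding (0, 0, 0, 0):
        #   h = (14 - 3 - (3 - 1) * (2 - 1) + 0 + 0) // 1 + 1 = 10, and likewise w = 10,
        # i.e. the usual output-size formula with an effective kernel extent of
        # 1 + (kernel - 1) * dilation.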
Eric Kunzee5e26762020-10-13 16:11:07 -07007385
Les Bell0e027d42021-11-09 14:42:14 +00007386 # Avoid illegal dimensions, which can be generated in error_if tests
7387 h = max(h, 1)
7388 w = max(w, 1)
7389
Eric Kunzee5e26762020-10-13 16:11:07 -07007390 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
7391
Kevin Cheng3a478572021-01-22 17:21:02 -08007392 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07007393 out_dtype = DType.INT32
7394 elif ifm.dtype == DType.INT16:
7395 out_dtype = DType.INT48
7396 elif ifm.dtype == DType.FLOAT:
7397 out_dtype = DType.FLOAT
Les Bell0e027d42021-11-09 14:42:14 +00007398 elif error_name == ErrorIf.WrongInputType:
7399 # Pick some potentially correct output dtype if input type is incorrect
7400 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07007401 else:
Les Bell0e027d42021-11-09 14:42:14 +00007402 raise Exception(f"Unsupported input dtype: {ifm.dtype}")
7403
7404 if error_name == ErrorIf.WrongOutputType:
7405 wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
7406 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07007407
Kevin Cheng550ccc52021-03-03 11:21:43 -08007408 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07007409
7410 @staticmethod
Les Bell0e027d42021-11-09 14:42:14 +00007411 def conv3dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
Kevin Cheng1533b852021-09-01 12:51:58 -07007412
7413 # IFM: NDHWC
7414 # Filter: ODHWI
7415 # OFM: NDHWC
7416
7417 d = (
7418 ifm.shape[1]
7419 - filter.shape[1]
7420 - (filter.shape[1] - 1) * (dilations[0] - 1)
7421 + padding[0]
7422 + padding[1]
7423 ) // strides[0] + 1
7424
7425 h = (
7426 ifm.shape[2]
7427 - filter.shape[2]
7428 - (filter.shape[2] - 1) * (dilations[1] - 1)
7429 + padding[2]
7430 + padding[3]
7431 ) // strides[1] + 1
7432
7433 w = (
7434 ifm.shape[3]
7435 - filter.shape[3]
7436 - (filter.shape[3] - 1) * (dilations[2] - 1)
7437 + padding[4]
7438 + padding[5]
7439 ) // strides[2] + 1
7440
Les Bell0e027d42021-11-09 14:42:14 +00007441 # Avoid illegal dimensions, which can be generated in error_if tests
7442 d = max(d, 1)
7443 h = max(h, 1)
7444 w = max(w, 1)
7445
Kevin Cheng1533b852021-09-01 12:51:58 -07007446 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
7447
7448 if ifm.dtype == DType.INT8:
7449 out_dtype = DType.INT32
7450 elif ifm.dtype == DType.INT16:
7451 out_dtype = DType.INT48
7452 elif ifm.dtype == DType.FLOAT:
7453 out_dtype = DType.FLOAT
Les Bell0e027d42021-11-09 14:42:14 +00007454 elif error_name == ErrorIf.WrongInputType:
7455 # Pick some potentially correct output dtype if input type is incorrect
7456 out_dtype = DType.INT32
Kevin Cheng1533b852021-09-01 12:51:58 -07007457 else:
Les Bell0e027d42021-11-09 14:42:14 +00007458 raise Exception(f"Unsupported input dtype: {ifm.dtype}")
7459
7460 if error_name == ErrorIf.WrongOutputType:
7461 wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
7462 out_dtype = rng.choice(wrong_dtypes)
Kevin Cheng1533b852021-09-01 12:51:58 -07007463
7464 return ser.addOutput(ofm_shape, out_dtype)
7465
7466 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007467 def depthwiseConv2dOp(
7468 ser, rng, ifm, filter, strides, padding, dilations, error_name=None
7469 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07007470 # IFM: NHWC
7471 # Filter: HWCM
7472 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08007473 h = (
7474 ifm.shape[1]
7475 - filter.shape[0]
7476 - (filter.shape[0] - 1) * (dilations[0] - 1)
7477 + padding[0]
7478 + padding[1]
7479 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07007480
Kevin Cheng550ccc52021-03-03 11:21:43 -08007481 w = (
7482 ifm.shape[2]
7483 - filter.shape[1]
7484 - (filter.shape[1] - 1) * (dilations[1] - 1)
7485 + padding[2]
7486 + padding[3]
7487 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07007488
Les Bell0e027d42021-11-09 14:42:14 +00007489 # Avoid illegal dimensions, which can be generated in error_if tests
7490 h = max(h, 1)
7491 w = max(w, 1)
7492
Eric Kunzee5e26762020-10-13 16:11:07 -07007493 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
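        # For illustration (hypothetical values): an HWCM filter of shape [3, 3, 8, 2]
        # applied to an NHWC input with C=8 channels yields C * M = 8 * 2 = 16
        # output channels.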
7494
Kevin Cheng3a478572021-01-22 17:21:02 -08007495 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07007496 out_dtype = DType.INT32
7497 elif ifm.dtype == DType.INT16:
7498 out_dtype = DType.INT48
7499 elif ifm.dtype == DType.FLOAT:
7500 out_dtype = DType.FLOAT
Les Bell0e027d42021-11-09 14:42:14 +00007501 elif error_name == ErrorIf.WrongInputType:
7502 # Pick some potentially correct output dtype if input type is incorrect
7503 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07007504 else:
Les Bell0e027d42021-11-09 14:42:14 +00007505 raise Exception(f"Unsupported input dtype: {ifm.dtype}")
7506
7507 if error_name == ErrorIf.WrongOutputType:
7508 wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
7509 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07007510
Kevin Cheng550ccc52021-03-03 11:21:43 -08007511 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07007512
7513 @staticmethod
Matthew Haddonb6b59e32021-10-07 17:19:20 +01007514 def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07007515 # input: NHWC
Matthew Haddonb6b59e32021-10-07 17:19:20 +01007516 if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007517            # If an incorrect stride is used, set dimensions to 1; the test is invalid anyway.
Matthew Haddonb6b59e32021-10-07 17:19:20 +01007518 h = 1
7519 w = 1
7520 else:
7521 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
7522 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
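            # Worked example (hypothetical values): for an input of 1x112x112x16,
            # kernel (3, 3), stride (2, 2) and pad (0, 1, 0, 1):
            #   h = (112 + 0 + 1 + 2 - 3) // 2 = 56, and likewise w = 56,
            # which matches floor((in + pad_before + pad_after - kernel) / stride) + 1.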
7523
7524 if error_name == ErrorIf.PoolingOutputShapeMismatch:
7525 choices = [1, 2, 3, 4, 5]
7526 h = h + rng.choice(choices)
7527 w = w + rng.choice(choices)
Eric Kunzee5e26762020-10-13 16:11:07 -07007528
Eric Kunzee5e26762020-10-13 16:11:07 -07007529 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01007530
7531 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007532 all_dtypes = [
7533 DType.INT8,
7534 DType.INT16,
7535 DType.INT32,
7536 DType.INT48,
7537 DType.FLOAT,
7538 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01007539 wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
7540 outputDType = rng.choice(wrong_dtypes)
7541 else:
7542 outputDType = ifm.dtype
7543
7544 return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07007545
7546 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007547 def fullyConnectedOp(ser, rng, input, filter, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07007548 # input: N, IC
7549 # filter: OC, IC
7550 # output: N, OC
7551
7552 output_shape = [input.shape[0], filter.shape[0]]
7553
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007554 if error_name == ErrorIf.WrongOutputType:
7555 if input.dtype == DType.INT8:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007556 incorrect_types = (
7557 DType.INT4,
7558 DType.INT8,
7559 DType.INT16,
7560 DType.INT48,
7561 DType.FLOAT,
7562 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007563 elif input.dtype == DType.INT16:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007564 incorrect_types = (
7565 DType.INT4,
7566 DType.INT8,
7567 DType.INT16,
7568 DType.INT32,
7569 DType.FLOAT,
7570 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007571 elif input.dtype == DType.FLOAT:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007572 incorrect_types = (
7573 DType.INT4,
7574 DType.INT8,
7575 DType.INT16,
7576 DType.INT32,
7577 DType.INT48,
7578 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007579 out_dtype = rng.choice(a=incorrect_types)
7580 elif input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07007581 out_dtype = DType.INT32
7582 elif input.dtype == DType.INT16:
7583 out_dtype = DType.INT48
7584 elif input.dtype == DType.FLOAT:
7585 out_dtype = DType.FLOAT
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007586 elif error_name == ErrorIf.WrongInputType:
7587 # Pick some potentially correct output dtype if input type is incorrect
7588 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07007589 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08007590            raise Exception(f"Unsupported input dtype: {input.dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -07007591
Kevin Cheng550ccc52021-03-03 11:21:43 -08007592 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07007593
7594 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007595 def matmulOp(ser, rng, a, b, error_name=None):
Kevin Cheng2d60f002021-06-09 14:18:32 -07007596 # a: N, H, C
7597 # b: N, C, W
7598 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07007599
Kevin Cheng2d60f002021-06-09 14:18:32 -07007600 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07007601
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007602 if error_name == ErrorIf.WrongOutputType:
7603 if a.dtype == DType.INT8:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007604 incorrect_types = (
7605 DType.INT4,
7606 DType.INT8,
7607 DType.INT16,
7608 DType.INT48,
7609 DType.FLOAT,
7610 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007611 elif a.dtype == DType.INT16:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007612 incorrect_types = (
7613 DType.INT4,
7614 DType.INT8,
7615 DType.INT16,
7616 DType.INT32,
7617 DType.FLOAT,
7618 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007619 elif a.dtype == DType.FLOAT:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00007620 incorrect_types = (
7621 DType.INT4,
7622 DType.INT8,
7623 DType.INT16,
7624 DType.INT32,
7625 DType.INT48,
7626 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007627 out_dtype = rng.choice(a=incorrect_types)
7628 elif a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07007629 out_dtype = DType.INT32
7630 elif a.dtype == DType.INT16:
7631 out_dtype = DType.INT48
7632 elif a.dtype == DType.FLOAT:
7633 out_dtype = DType.FLOAT
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007634 elif error_name == ErrorIf.WrongInputType:
7635 # Pick some potentially correct output dtype if input type is incorrect
7636 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07007637 else:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01007638            raise Exception(f"Unsupported input dtype for matmul: {a.dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -07007639
Kevin Cheng550ccc52021-03-03 11:21:43 -08007640 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07007641
    @staticmethod
    def concatOp(ser, rng, axis, *a, error_name=None):
        input1 = a[0]
        remaining_inputs = a[1:]

        # calculate the output shape, if possible, otherwise just use the first input shape
        output_shape = input1.shape.copy()
        if not (
            # unable to concat tensors of different ranks
            error_name == ErrorIf.ConcatInputRankMismatch
            # unable to concat tensors along an invalid axis
            or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
        ):
            for tensor in remaining_inputs:
                output_shape[axis] += tensor.shape[axis]

        if error_name == ErrorIf.ConcatShapeSumMismatch:
            output_shape[axis] += rng.integers(5, 10)

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = {
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            }
            wrong_dtypes = list(all_dtypes - set([input1.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = input1.dtype

        return ser.addOutput(output_shape, outputDType)

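    # Illustrative worked example for concatOp above (comment only, not part of
    # the original generator): concatenating shapes [2, 3, 4] and [2, 5, 4]
    # along axis=1 accumulates the axis sizes, giving output_shape = [2, 8, 4];
    # for the rank/axis ErrorIf cases the first input's shape is kept as-is.
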
    @staticmethod
    def padOp(ser, rng, a, padding, error_name=None):

        output_shape = a.shape.copy()

        for i in range(len(output_shape)):
            output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]

        # Fix negative output shape if error_if test causes it
        if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1:
            output_shape = [i if i >= 1 else 1 for i in output_shape]

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(output_shape, outputDType)

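    # Illustrative worked example for padOp above (comment only, not part of
    # the original generator): with a.shape = [3, 4] and
    # padding = [[1, 1], [0, 2]], each dimension grows by its before/after
    # padding, so output_shape = [1 + 1 + 3, 0 + 2 + 4] = [5, 6].
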
    @staticmethod
    def reshapeOp(ser, rng, a, shape, error_name=None):
        output_shape = shape.copy()

        totalElements = 1
        for i in a.shape:
            totalElements *= i

        # If there are any -1 elements, figure out what that dimension must be
        totalOutputElements = 1
        for i in output_shape:
            if i != -1:
                totalOutputElements *= i

        # And fill it in
        for i in range(len(output_shape)):
            if output_shape[i] == -1:
                output_shape[i] = totalElements // totalOutputElements

        if error_name == ErrorIf.TensorSizeInputOutputMismatch:
            for i in range(len(output_shape)):
                output_shape[i] = output_shape[i] + rng.integers(1, 10)

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(output_shape, outputDType)

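    # Illustrative worked example for reshapeOp above (comment only, not part
    # of the original generator): reshaping a.shape = [2, 3, 4] (24 elements)
    # with shape = [6, -1] infers the -1 dimension as 24 // 6 = 4, so
    # output_shape = [6, 4].
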
    @staticmethod
    def sliceOp(ser, rng, a, start, size, error_name=None):

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        if error_name == ErrorIf.SizeOutputShapeMismatch:
            output_shape = size.copy()
            for index in range(len(output_shape)):
                if output_shape[index] <= 2:
                    output_shape[index] = output_shape[index] + rng.choice([1, 2])
                else:
                    output_shape[index] = output_shape[index] + rng.choice(
                        [-2, -1, 1, 2]
                    )
        else:
            output_shape = size.copy()

        return ser.addOutput(output_shape, outputDType)

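    # Illustrative worked example for sliceOp above (comment only, not part of
    # the original generator): the output shape is simply the requested size,
    # e.g. size = [1, 2, 3] -> output_shape = [1, 2, 3]; the
    # SizeOutputShapeMismatch case perturbs each dimension by a small random
    # amount to construct a deliberately wrong output.
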
    @staticmethod
    def tileOp(ser, rng, a, multiples, error_name=None):

        output_shape = a.shape.copy()
        assert len(multiples) == len(output_shape)

        for i in range(len(output_shape)):
            output_shape[i] = a.shape[i] * multiples[i]

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(output_shape, outputDType)

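    # Illustrative worked example for tileOp above (comment only, not part of
    # the original generator): with a.shape = [2, 3] and multiples = [1, 4],
    # each dimension is scaled by its multiple, giving
    # output_shape = [2 * 1, 3 * 4] = [2, 12].
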
    @staticmethod
    def transposeOp(ser, rng, a, perms, error_name=None):
        output_shape = a.shape.copy()

        assert len(perms) == len(output_shape)

        if error_name == ErrorIf.IndexOutsideBounds:
            for i in range(len(output_shape)):
                output_shape[i] = a.shape[0]
        else:
            for i in range(len(output_shape)):
                output_shape[i] = a.shape[perms[i]]

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(output_shape, outputDType)

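    # Illustrative worked example for transposeOp above (comment only, not part
    # of the original generator): with a.shape = [2, 3, 4] and perms = [2, 0, 1],
    # output_shape[i] takes a.shape[perms[i]], giving output_shape = [4, 2, 3].
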
    @staticmethod
    def gatherOp(ser, rng, values, indices, error_name=None):
        if error_name != ErrorIf.WrongRank:
            assert len(values.shape) == 3
            assert len(indices.shape) == 2
            assert values.shape[0] == indices.shape[0]

        output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = values.dtype

        return ser.addOutput(output_shape, outputDType)

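    # Illustrative worked example for gatherOp above (comment only, not part of
    # the original generator): with values.shape = [2, 10, 7] (N, K, C) and
    # indices.shape = [2, 5] (N, W), output_shape = [2, 5, 7] (N, W, C).
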
    @staticmethod
    def scatterOp(ser, rng, values_in, indices, input, error_name=None):
        if error_name != ErrorIf.WrongRank:
            assert len(values_in.shape) == 3
            assert len(indices.shape) == 2
            assert len(input.shape) == 3
            assert values_in.shape[0] == indices.shape[0]  # N
            assert input.shape[1] == indices.shape[1]  # W
            assert values_in.shape[2] == input.shape[2]  # C

        output_shape = values_in.shape

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = values_in.dtype

        return ser.addOutput(output_shape, outputDType)

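    # Illustrative worked example for scatterOp above (comment only, not part
    # of the original generator): the output always matches values_in, e.g.
    # values_in.shape = [2, 10, 7] (N, K, C) with indices.shape = [2, 5] and
    # input.shape = [2, 5, 7] gives output_shape = [2, 10, 7].
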
    @staticmethod
    def tableOp(ser, rng, input, error_name=None):
        # Same shape as the input, dtype dependent on input dtype
        if error_name != ErrorIf.WrongInputType:
            assert input.dtype == DType.INT16 or input.dtype == DType.INT8
        output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
        if error_name == ErrorIf.WrongOutputType:
            wrong_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes.remove(output_dtype)
            output_dtype = rng.choice(wrong_dtypes)
        return ser.addOutput(input.shape, output_dtype)

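    # Illustrative worked example for tableOp above (comment only, not part of
    # the original generator): an INT8 input yields an INT8 output and an INT16
    # input yields an INT32 output, both with the same shape as the input.
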
    @staticmethod
    def resizeOp(
        serializer,
        rng,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        error_name=None,
    ):
        if error_name == ErrorIf.WrongRank:
            output_dims = [
                input.shape[0],
                output_dims[0],
                output_dims[0],
                input.shape[0],
            ]
        else:
            if error_name == ErrorIf.BatchMismatch:
                output_dims = [
                    input.shape[0] + rng.integers(1, 10),
                    output_dims[0],
                    output_dims[1],
                    input.shape[3],
                ]
            elif error_name == ErrorIf.ChannelMismatch:
                output_dims = [
                    input.shape[0],
                    output_dims[0],
                    output_dims[1],
                    input.shape[3] + rng.integers(1, 10),
                ]
            else:
                output_dims = [
                    input.shape[0],
                    output_dims[0],
                    output_dims[1],
                    input.shape[3],
                ]

        return serializer.addOutput(output_dims, output_dtype)

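    # Illustrative worked example for resizeOp above (comment only, not part of
    # the original generator): for an NHWC input of shape [1, 16, 16, 3] and
    # output_dims = [32, 32], the valid-case output shape is [1, 32, 32, 3];
    # BatchMismatch and ChannelMismatch inflate N or C respectively to force an
    # ERROR_IF condition.
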
    @staticmethod
    def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
        return ser.addOutput(val.shape, out_dtype)

    @staticmethod
    def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None):
        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            out_dtype = DType.INT32
        else:
            raise Exception(f"Unsupported input dtype: {ifm.dtype}")

        if error_name == ErrorIf.WrongOutputType:
            wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
            out_dtype = rng.choice(wrong_dtypes)

        return ser.addOutput(output_shape, out_dtype)
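
    # Illustrative worked example for transposeConv2DOp above (comment only,
    # not part of the original generator): an INT8 ifm produces an INT32
    # accumulator output, INT16 produces INT48, and FLOAT stays FLOAT; the
    # WrongOutputType case then swaps in any other usable dtype.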