# Copyright (c) 2021-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
import logging
import math

import generator.tosa_utils as gtu
import numpy as np
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode

# DTypeNames, DType, Op and ResizeMode are convenience variables to the
# flatc-generated types that should be enums, but aren't

logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")


class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

29 def __init__(self):
30 pass
31
32 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010033 def getZeroPoint(rng, zeropoint, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010034
35 if dtype == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010036 if zeropoint is not None:
37 return min(127, max(-128, zeropoint))
38 return rng.randInt(-128, 128)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010039 elif dtype == DType.UINT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010040 if zeropoint is not None:
41 return min(255, max(0, zeropoint))
42 return rng.randInt(0, 256)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010043 elif error_name in [
44 ErrorIf.InputZeroPointNotZero,
45 ErrorIf.WeightZeroPointNotZero,
46 ErrorIf.OutputZeroPointNotZero,
47 ]:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010048 zero_point = rng.randInt(-128, 128)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010049 if zero_point == 0:
50 zero_point = 1
51 return zero_point
52 return 0
53
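    # Illustrative examples of the behaviour above (assumed values, not taken
    # from any particular test run):
    #   getZeroPoint(rng, 300, DType.UINT8)  -> 255 (clamped to the UINT8 range)
    #   getZeroPoint(rng, None, DType.INT8)  -> a random INT8 zero point
    #   getZeroPoint(rng, None, DType.INT32) -> 0 (other types fall through to zero)
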
    @staticmethod
    def qgUnary(rng, zeropoint, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
            ]
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
            ]
        return qinfo

    @staticmethod
    def qgConv(rng, zeropoint, op, dtype_or_dtypeList, error_name=None):
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0], error_name),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1]),
            ]
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0]),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1], error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0]),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1]),
            ]
        return qinfo

    @staticmethod
    def qgMatmul(rng, zeropoint, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
            ]
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        logger.debug(
            f"computeMultiplierAndShift: scalefp={scaleFp} scaleBits={scaleBits} m={m} mult={multiplier} shift={shift}"
        )

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift

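    # Worked example (illustrative, following the arithmetic above): for
    # scaleFp = 0.015625 (2^-6) with scale32=True, frexp() returns m=0.5 and
    # shift=-5, so multiplier = round(0.5 * 2^31) = 1073741824 and the final
    # shift = 5 + 31 = 36, giving multiplier / 2^shift == scaleFp exactly.
    #   >>> TosaQuantGen.computeMultiplierAndShift(0.015625, True)
    #   (1073741824, 36)
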

class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]
        shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rng, rank + rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rng, rank + rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rng, rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgGather(testGen, rng, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_shape = testGen.makeShape(rng, rank)
        values_shape = testGen.constrictBatchSize(values_shape)

        N = values_shape[0]
        W = testGen.makeDimension(rng)
        indices_shape = [N, W]

        shape_list = [values_shape, indices_shape]
        return shape_list

    @staticmethod
    def tgScatter(testGen, rng, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rng, rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = rng.randInt(W_min, W_max) if W_min < W_max else W_min

        input_shape = [N, W, C]

        shape_list = []
        shape_list.append(values_in_shape)
        shape_list.append([N, W])  # indices
        shape_list.append(input_shape)

        return shape_list

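    # Illustrative example (assumed values): with values_in_shape = [2, 4, 8]
    # the K dimension is 4, so W_min and W_max are both capped at 4 and the
    # chosen W can never exceed K; each of the W indices can then write to a
    # distinct location along K.
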
260 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100261 def _get_broadcast_shapes(testGen, rng, num_shapes, rank, error_name=None):
262 shape = testGen.makeShape(rng, rank)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100263 shape_list = []
264
265 # Choose one of the inputs to broadcast
266 # Note: Simplifies OutputShaper code if we don't change first shape for errors
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100267 bcast_idx = rng.randInt(0 if error_name is None else 1, num_shapes)
268 fuzz_idx = rng.randInt(0, rank)
Jerry Ge135c9552023-05-23 20:59:32 +0000269
Jeremy Johnson0a042992024-02-28 13:20:05 +0000270 for i in range(num_shapes):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100271 shape_bcast = shape.copy()
272
Jerry Ge135c9552023-05-23 20:59:32 +0000273 # To test broadcasting, the chosen fuzz index dimension should not be 1
274 if shape_bcast[fuzz_idx] == 1:
275 shape_bcast[fuzz_idx] += 1
276
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100277 # If the chosen input, pick a random index to broadcast
278 if i == bcast_idx:
Jerry Ge135c9552023-05-23 20:59:32 +0000279 if error_name == ErrorIf.RankMismatch:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100280 # Add one rank to the shape (or more for rank of 1)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100281 extra_ranks = rng.choice([1, 2, 3]) if rank == 1 else 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100282 shape_bcast = np.concatenate(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100283 (shape_bcast, testGen.makeShape(rng, extra_ranks))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100284 )
285 if rank != 1:
286 # Either keep the extra rank, or remove it
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100287 new_len = rng.choice([-2, len(shape_bcast)])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100288 shape_bcast = shape_bcast[:new_len]
Jerry Ge135c9552023-05-23 20:59:32 +0000289 elif error_name == ErrorIf.BroadcastShapesMismatch:
290 shape_bcast[fuzz_idx] += 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100291 else:
292 shape_bcast[fuzz_idx] = 1
293
294 shape_list.append(shape_bcast)
295
296 return shape_list
297
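    # Illustrative example (assumed values, no error_name): for a base shape of
    # [3, 1, 5] with fuzz_idx = 1 and bcast_idx = 1, the non-chosen input is
    # bumped to [3, 2, 5] and the chosen input keeps a broadcastable 1 in that
    # dimension, giving shape_list = [[3, 2, 5], [3, 1, 5]].
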
    @staticmethod
    def tgBroadcastFuzz(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]
        num_shapes = pl + const
        return TosaTensorGen._get_broadcast_shapes(
            testGen, rng, num_shapes, rank, error_name
        )

    @staticmethod
    def tgMul(testGen, rng, op, rank, error_name=None):
        # Get broadcast shapes for the first 2 inputs as the 3rd is shift
        shape_list = TosaTensorGen._get_broadcast_shapes(
            testGen, rng, 2, rank, error_name
        )
        # Add a single dimension tensor for shift
        shape_list.append([1])
        return shape_list

    @staticmethod
    def tgConv2D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension(rng)

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension(rng)

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension(rng)

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension(rng) % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rng, rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rng, rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]
        shape = testGen.makeShape(rng, rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = rng.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(rng, shapeList, axis, error_name=None):
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList


class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each tensor in each test."""

    def __init__(self):
        pass

    class TVGInfo:
        """Enhanced tensor values information including data gen dict."""

        def __init__(self, tensorList, dataGenDict):
            self.tensorList = tensorList
            self.dataGenDict = dataGenDict

    # Default high value for random numbers
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
        DType.FP8E4M3: 448,
        DType.FP8E5M2: 57344,
    }

    # Default lowest normal values for random numbers
    TVG_FLOAT_LOW_VALUE = {
        DType.FP32: np.exp2(-126),
        DType.FP16: np.exp2(-14),
        DType.BF16: np.exp2(-126),
        DType.FP8E4M3: np.exp2(-9),
        DType.FP8E5M2: np.exp2(-16),
    }

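    # For reference, the high-value expressions above evaluate to the largest
    # finite value of each format, e.g. the FP32 entry is
    # (1 << 128) - (1 << 104), approximately 3.4028235e38, and the FP16 entry
    # is 65536 - 32 = 65504.
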
    @staticmethod
    def _get_data_range(rng, dtype, highValueLookup, lowValueLookup=None):
        # Return a tuple of (low,high) data range values for the given data
        # type using a combination of per operator table limits, data limits
        # and user supplied ranges for FP numbers
        if dtype in highValueLookup:
            type_range = rng.dTypeRange(dtype, high_inclusive=True)
            high_val = highValueLookup[dtype]
            if lowValueLookup is not None and dtype in lowValueLookup:
                low_val = lowValueLookup[dtype]
            else:
                low_val = -high_val
            # Set the values to something that won't produce infinity whilst
            # respecting the default ranges if more/less than the low/high
            # values
            data_range = (
                max(low_val, type_range[0]),
                min(high_val, type_range[1]),
            )
            if data_range[0] > data_range[1]:
                # Invalid data range from low to high created due to user
                # constraints revert to using internal ranges as they are
                # known to work
                logger.info(
                    f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
                )
                data_range = (low_val, high_val)
            return data_range
        return None

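    # Illustrative example (assumed values): for DType.FP16 with the default
    # TVG_FLOAT_HIGH_VALUE table and a wider user-supplied range such as
    # (-100000.0, 100000.0), the returned range is clamped to
    # (-65504.0, 65504.0) so generated values stay finite in half precision.
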
    @staticmethod
    def tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
        if "p_count" in argsDict:
            # Override for operators like CONCAT
            pCount = argsDict["p_count"]
            cCount = argsDict["c_count"]
        assert pCount + cCount == len(
            shapeList
        ), "Placeholders & Constant tensors must match shapes list"

        tens_ser_list = []

        if (
            error_name is not None
            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
            or "data_gen" not in testGen.TOSA_OP_LIST[opName]
        ):
            # Fall back to internal data gen when dealing with unsupported types or ops
            data_range = argsDict["data_range"] if "data_range" in argsDict else None
            for idx, info in enumerate(zip(shapeList, dtypeList)):
                roundMode = False
                shape, dtype = info
                if "data_range_list" in argsDict:
                    data_range = argsDict["data_range_list"][idx]["range"]
                    roundMode = (
                        "round" in argsDict["data_range_list"][idx]
                        and argsDict["data_range_list"][idx]["round"] is True
                    )
                if data_range is not None and dtype not in (
                    DType.FP16,
                    DType.FP32,
                    DType.BF16,
                    DType.FP8E4M3,
                    DType.FP8E5M2,
                ):
                    # Change from inclusive to exclusive range
                    data_range = (data_range[0], data_range[1] + 1)

                # Ignore lazy data gen option and create data array using any range limits
                if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                    if dtype == DType.SHAPE:
                        arr = np.int64(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT8:
                        arr = np.int8(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT16:
                        arr = np.int16(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT32:
                        arr = np.int32(argsDict["fixed_data"][idx])
                    else:
                        assert False, "Unsupported fixed_data type"
                else:
                    arr = rng.randTensor(shape, dtype, data_range)
                if roundMode:
                    arr = np.round(arr)
                if idx < pCount:
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
        dg_type = argsDict["dg_type"]
        tens_data = {
            "version": "0.1",
            "tensors": {},
        }
        dg_tens_meta = tens_data["tensors"]
        for idx, shape in enumerate(shapeList):

            tens_meta = {}
            if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                tens_meta["generator"] = gtu.DataGenType(
                    gtu.DataGenType.FIXED_DATA
                ).name
            else:
                tens_meta["generator"] = gtu.DataGenType(dg_type).name

            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()

            if idx < pCount:
                tens_meta["input_type"] = "VARIABLE"
            else:
                tens_meta["input_type"] = "CONSTANT"

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                if (
                    tens_meta["generator"]
                    == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name
                ):
                    info["data"] = [int(i) for i in argsDict["fixed_data"][idx]]
                    tens_meta["fixed_data_info"] = info
                else:
                    info["rng_seed"] = rng.seed

                    data_range = None
                    if "data_range_list" in argsDict:
                        data_range = argsDict["data_range_list"][idx]["range"]
                        if "round" in argsDict["data_range_list"][idx]:
                            info["round"] = argsDict["data_range_list"][idx]["round"]
                    elif "data_range" in argsDict:
                        data_range = argsDict["data_range"]

                    if data_range is None:
                        data_range = rng.dTypeRange(dtypeList[idx], high_inclusive=True)
                    info["range"] = [str(v) for v in data_range]
                    tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = int(argsDict["ks"])
                if "acc_type" in argsDict:
                    # Convert type number into JSON name
                    info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
                        "json"
                    ]
                if "kernel" in argsDict:
                    info["kernel"] = [int(k) for k in argsDict["kernel"]]
                if "axis" in argsDict:
                    info["axis"] = int(argsDict["axis"])
                tens_meta["dot_product_info"] = info
            elif dg_type == gtu.DataGenType.FULL_RANGE:
                info = {}
                info["start_val"] = int(
                    rng.randInt(0, gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["fullset"])
                )
                tens_meta["full_range_info"] = info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"

            # Using the finished generate config meta data - generate the data if
            # needed and assign a tensor name from the serializer

            # Need to generate data when not lazy or for the bias tensor as we need
            # to work out if the bias data is non-zero for compliance
            if not testGen.args.lazy_data_gen or (
                idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
            ):
                # Give this tensor a temporary name until we get one from the serializer
                temp_name = f"placeholder_{idx}"
                dg_tens_meta[temp_name] = tens_meta
                # Create data now using the temporary name to access meta details
                data = testGen.dgl.get_tensor_data(temp_name, tens_data)
                if tens_meta["data_type"] == "SHAPE":
                    # Tensor type SHAPE and Numpy file type must be the same
                    data = np.int64(data)
                # Remove the item as we will give it the correct name later
                del dg_tens_meta[temp_name]

            if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
                # The KS value used by compliance verification is altered when the
                # bias data is non-zero
                if max(abs(data)) > 0.0:
                    argsDict["ksb"] = argsDict["ks"] + 1

            if testGen.args.lazy_data_gen:
                data = None

            if tens_meta["input_type"] == "VARIABLE":
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
            else:
                tens = testGen.ser.addConst(shape, dtypeList[idx], data)

            tens_ser_list.append(tens)
            # Add the meta data to the list using the serializer tensor name
            dg_tens_meta[tens.name] = tens_meta

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)

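    # For orientation, the generated meta-data dictionary looks roughly like the
    # sketch below (illustrative only; tensor names come from the serializer and
    # the *_info contents depend on the data generator type):
    #   {
    #       "version": "0.1",
    #       "tensors": {
    #           "input-0": {
    #               "generator": "PSEUDO_RANDOM",
    #               "data_type": "FP32",
    #               "shape": [1, 8, 8, 3],
    #               "input_pos": 0,
    #               "op": "CONV2D",
    #               "input_type": "VARIABLE",
    #               "pseudo_random_info": {"rng_seed": 42, "range": ["-2.0", "2.0"]},
    #           },
    #       },
    #   }
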
    @staticmethod
    def tvgNegate(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Integer test
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.NEGATE must have 1 placeholder, 0 consts"
            # Must create tensors with values within accumulator (int32) negatable
            # range
            max_val = (1 << 31) - 1
            min_val = -max_val
            arr = np.int32(
                rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
            )
            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )

    # Set the ADD/SUB data range to half the largest value to avoid infinities
    TVG_FLOAT_HIGH_VALUE_ADDSUB = {
        DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
        DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
        DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
    }

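    # Illustrative check of the halved limits (assumed worst case): for FP32 the
    # ADD/SUB limit is roughly 1.7e38, and adding two such values gives roughly
    # 3.4e38, which is still representable as a finite float32, so in-range
    # inputs cannot overflow to infinity.
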
    @staticmethod
    def tvgAddSub(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
            # Make sure the integer operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            tens_ser_list = []
            add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
            data_range = None  # Use default
            if op["op"] in (Op.ADD_SHAPE, Op.SUB_SHAPE):
                data_range = testGen.args.tensor_shape_range
            a_arr = rng.randTensor(shapeList[0], dtypeList[0], data_range)
            b_arr = rng.randTensor(shapeList[1], dtypeList[1], data_range)
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            data_range = TosaTensorValuesGen._get_data_range(
                rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
            )
            if data_range:
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )

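    # Worked example of the clipping above (assumed values): with a = 2147483640
    # and b = 100, the int64 sum 2147483740 exceeds INT32_MAX by 93, so
    # sat_max_arr = 93 and b is adjusted to 100 - 93 = 7; the ADD test then
    # computes 2147483640 + 7 = 2147483647 without saturating.
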
    @staticmethod
    def tvgCondIfWhileLoop(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        if dtypeList[0] in (
            DType.INT32,
            DType.INT16,
            DType.INT8,
        ):
            # Limit input tensors with cond_if_binary or while_loop to stop
            # saturation of add/sub ops with int32 and keep all logical shift
            # values between 0 to 31 for int16 or int8
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            pRemain = pCount
            tens_ser_list = []
            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[0] == DType.INT32:
                    arr = rng.randTensor(shapeList[idx], DType.INT16)
                else:
                    arr = np.int32(rng.integers(low=0, high=32, size=shapeList[idx]))
                if pRemain > 0:
                    tens_ser_list.append(
                        testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
                    )
                    pRemain -= 1
                else:
                    tens_ser_list.append(
                        testGen.ser.addConst(shape, dtypeList[idx], arr)
                    )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )

    @staticmethod
    def tvgArithmeticRightShift(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        op = testGen.TOSA_OP_LIST[opName]
        pCount, cCount = op["operands"]
        # Force value of operand[1] to be within [0, num_bits]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

        tens_ser_list = []
        for idx, shape in enumerate(shapeList[:]):
            if idx == 1:
                if dtypeList[idx] == DType.INT8:
                    arr = np.int32(rng.integers(low=0, high=8, size=shape))
                elif dtypeList[idx] == DType.INT16:
                    arr = np.int32(rng.integers(low=0, high=16, size=shape))
                elif dtypeList[idx] == DType.INT32:
                    arr = np.int32(rng.integers(low=0, high=32, size=shape))
                elif error_name == ErrorIf.WrongInputType:
                    arr = np.int32(rng.integers(low=0, high=8, size=shape))
                else:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
            else:
                arr = rng.randTensor(shape, dtypeList[idx])
            tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

    @staticmethod
    def tvgReshape(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        dtypeList[1] = DType.SHAPE
        shapeList[1] = [len(argsDict["new_shape"])]
        # Create a new list for the pre-generated data in argsDict["fixed_data"]
        argsDict["fixed_data"] = [None, argsDict["new_shape"]]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgRescale(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        scale32 = argsDict["scale"]
        multiplier_arr = argsDict["multiplier"]
        shift_arr = argsDict["shift"]

        if scale32:
            dtypeList[1] = DType.INT32
        else:
            dtypeList[1] = DType.INT16
        shapeList[1] = [len(multiplier_arr)]
        dtypeList[2] = DType.INT8
        shapeList[2] = [len(shift_arr)]
        # Create a new list for the pre-generated data in argsDict["fixed_data"]
        argsDict["fixed_data"] = [None, multiplier_arr, shift_arr]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgPad(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
        # argsDict["pad"] is 2D array, need to flatten it to get list of values
        pad_values = argsDict["pad"].flatten()
        dtypeList[1] = DType.SHAPE
        shapeList[1] = [len(pad_values)]
        # Create a new list for the pre-generated data in argsDict["fixed_data"]
        argsDict["fixed_data"] = [None, pad_values]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgSlice(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
        dtypeList[1] = DType.SHAPE
        shapeList[1] = [len(argsDict["start"])]
        dtypeList[2] = DType.SHAPE
        shapeList[2] = [len(argsDict["size"])]
        # Create a new list for the pre-generated data in argsDict["fixed_data"]
        argsDict["fixed_data"] = [None, argsDict["start"], argsDict["size"]]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgTile(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
        dtypeList[1] = DType.SHAPE
        shapeList[1] = [len(argsDict["multiples"])]
        argsDict["fixed_data"] = [None, argsDict["multiples"]]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgSelect(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        # Set datatype of condition tensor to boolean
        dtypeList[0] = DType.BOOL

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgIntDiv(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        if error_name is None:
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            tens_ser_list = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
            while True:
                dividend_arr = rng.randTensor(shapeList[0], dtypeList[0])
                divisor_arr = rng.randTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
                    continue

                break

            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )

    # Set the MUL data range to the square root of the largest value
    # to avoid infinities
    TVG_FLOAT_HIGH_VALUE_MUL = {
        DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }

Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001200 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001201 def tvgMul(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001202 if error_name is not None or dtypeList[0] in (
1203 DType.FP16,
1204 DType.BF16,
1205 DType.FP32,
1206 ):
1207 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001208 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001209 rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001210 )
1211 if data_range:
1212 argsDict["data_range"] = data_range
1213
Jeremy Johnson0a042992024-02-28 13:20:05 +00001214 if dtypeList[0] != DType.SHAPE:
1215 # Need to supply shift tensor for MUL (not needed for MUL_SHAPE)
1216 dtypeList[2] = DType.INT8
1217 shapeList[2] = [1]
1218 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1219 argsDict["fixed_data"] = [None, None, [argsDict["shift"]]]
1220
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001221 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001222 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001223 )
1224 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001225 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001226 pCount, cCount = op["operands"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001227
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001228 tens_ser_list = []
1229
1230            # Make sure the multiply result is within the int32 range
Won Jeon74342e52024-01-09 00:34:40 +00001231 if dtypeList[0] == DType.SHAPE:
1232 shift = 0
1233 else:
1234 shift = argsDict["shift"]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001235 if dtypeList[0] == DType.INT8:
1236 num_bits = 8
1237 elif dtypeList[0] == DType.INT16:
1238 num_bits = 16
Won Jeon74342e52024-01-09 00:34:40 +00001239 elif dtypeList[0] in (DType.INT32, DType.SHAPE):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001240 num_bits = 32
1241 elif error_name == ErrorIf.WrongInputType:
1242 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001243 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001244 raise Exception(
1245 f"OpMul: invalid input dtype {gtu.DTYPE_ATTRIBUTES[dtypeList[0]]['str']}"
1246 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001247
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001248 for idx, shape in enumerate(shapeList[:]):
Won Jeon74342e52024-01-09 00:34:40 +00001249 if dtypeList[idx] == DType.SHAPE:
1250 low = testGen.args.tensor_shape_range[0]
1251 high = testGen.args.tensor_shape_range[1]
1252 else:
1253 low = -(2 ** (num_bits - 1))
1254 high = (2 ** (num_bits - 1)) - 1
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001255
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001256 a_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[0]))
1257 b_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[1]))
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001258
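            # Repeatedly halve the inputs until the (rounded and shifted)
            # product of every element fits within the int32 range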
1259 i = 0
1260 while True:
1261
1262 a_arr_64 = a_arr.astype(np.int64)
1263 b_arr_64 = b_arr.astype(np.int64)
1264
1265 if shift > 0:
1266 rounding = 1 << (shift - 1)
1267 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001268 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001269 result_arr = a_arr_64 * b_arr_64
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001270
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001271 if (result_arr > -(2**31)).all() and (
1272 result_arr <= ((2**31) - 1)
1273 ).all():
1274 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001275
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001276 i = i + 1
1277 a_arr = a_arr // 2
1278 b_arr = b_arr // 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001279
Won Jeon74342e52024-01-09 00:34:40 +00001280 if dtypeList[0] == DType.SHAPE:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001281 # MUL_SHAPE with 2 inputs
Won Jeon74342e52024-01-09 00:34:40 +00001282 tens_ser_list.append(
1283 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
1284 )
1285 tens_ser_list.append(
1286 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
1287 )
1288 else:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001289 # MUL with 3 inputs (3rd is shift)
Won Jeon74342e52024-01-09 00:34:40 +00001290 tens_ser_list.append(
1291 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1292 )
1293 tens_ser_list.append(
1294 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1295 )
Jeremy Johnson0a042992024-02-28 13:20:05 +00001296 tens_ser_list.append(
1297 testGen.ser.addPlaceholder([1], DType.INT8, np.int8([shift]))
1298 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001299
1300 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001301
1302 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001303 def tvgConcat(
1304 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1305 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001306 count = len(shapeList) - testGen.args.num_const_inputs_concat
1307 if count < 1:
1308 count = 1
1309 if testGen.args.num_const_inputs_concat == 0:
1310 count = len(shapeList)
1311
Won Jeon74342e52024-01-09 00:34:40 +00001312 op = testGen.TOSA_OP_LIST[opName]
1313 if op["op"] == Op.CONCAT_SHAPE:
1314 # Set the axis to 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001315 shapeList = TosaTensorGen.tgConcatConstInput(rng, shapeList, 0, error_name)
Won Jeon74342e52024-01-09 00:34:40 +00001316 else:
1317 shapeList = TosaTensorGen.tgConcatConstInput(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001318 rng, shapeList, argsDict["axis"], error_name
Won Jeon74342e52024-01-09 00:34:40 +00001319 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001320
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001321 # Override default pCount/cCount for operator
1322 argsDict["p_count"] = count
1323 argsDict["c_count"] = len(shapeList) - count
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001324
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001325 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001326 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001327 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001328
1329 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001330 def tvgLogicalShift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001331 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001332 ):
1333 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001334 pCount, cCount = op["operands"]
1335 assert (
1336 pCount == 2 and cCount == 0
1337 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001338 values_arr = rng.randTensor(shapeList[0], dtypeList[0])
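        # Shift amounts are limited to 0 to 31, the valid range for 32-bit shifts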
1339 shift_arr = np.int32(rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001340 tens_ser_list = []
1341 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001342 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1343 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001344 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001345 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1346 )
1347
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001348 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001349
1350 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001351 def tvgEqual(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnsona0150012023-11-15 15:52:06 +00001352 if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
1353 # Integer
1354 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001355 pCount, cCount = op["operands"]
1356 assert (
1357 pCount == 2 and cCount == 0
1358 ), "Op.EQUAL must have 2 placeholders, 0 consts"
Jeremy Johnsona0150012023-11-15 15:52:06 +00001359
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001360 a_arr = rng.randTensor(shapeList[0], dtypeList[0])
1361 b_arr = rng.randTensor(shapeList[1], dtypeList[1])
Jeremy Johnsona0150012023-11-15 15:52:06 +00001362
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001363            # Using random numbers makes it very unlikely that any values will
1364            # match (be equal), therefore force in twice as many matching
1365            # values as the tensor rank
1366 for num in range(0, len(shapeList[0]) * 2):
1367 a_index = []
1368 b_index = []
1369 # Choose an index in each axis for the whole shape
1370 for axis in range(0, len(shapeList[0])):
1371 # Index can be up to the largest dimension in both shapes
1372 index = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001373 rng.integers(0, max(shapeList[0][axis], shapeList[1][axis]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001374 )
1375 # Reduce the index down to a shape's dim for broadcasting
1376 a_index.append(min(shapeList[0][axis] - 1, index))
1377 b_index.append(min(shapeList[1][axis] - 1, index))
1378
1379 a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
1380
Jeremy Johnsona0150012023-11-15 15:52:06 +00001381 tens_ser_list = []
1382 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001383 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1384 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001385 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001386 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1387 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001388 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001389 else:
Jeremy Johnsona0150012023-11-15 15:52:06 +00001390 # ERROR_IF or floating point test
1391 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001392 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001393 )
1394
1395 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001396 def tvgReduceSum(
1397 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1398 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001399 dtype = dtypeList[0]
1400 if dtype == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001401 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001402 pCount, cCount = op["operands"]
1403 assert (
1404 pCount == 1 and cCount == 0
1405            ), "Op.REDUCE_SUM must have 1 placeholder, 0 consts"
1406 # Limit values so that the sum cannot exceed the range of an int32 during
1407 # summation of any axis
1408 range_val = int((1 << 31) / max(shapeList[0]))
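            # For example, a largest dimension of 1000 limits the values to
            # approximately +/-2.1 million (2**31 / 1000)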
1409 values_arr = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001410 rng.integers(low=-range_val, high=range_val, size=shapeList[0])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001411 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001412 tens_ser_list = []
1413 tens_ser_list.append(
Jeremy Johnson30476252023-11-20 16:15:30 +00001414 testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001415 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001416 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001417 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001418 # ERROR_IF or dot product floating point test
Jeremy Johnson30476252023-11-20 16:15:30 +00001419 if (
1420 error_name is None
1421 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
1422 ):
1423                # Limit ranges for (non error & non compliance) tests by using
1424                # values that can be summed along any axis without hitting infinity
1425 highval_lookup = {
1426 dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
1427 / max(shapeList[0])
1428 }
1429 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001430 rng, dtype, highval_lookup
Jeremy Johnson30476252023-11-20 16:15:30 +00001431 )
1432 assert data_range is not None
1433 argsDict["data_range"] = data_range
1434
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001435 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001436 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001437 )
1438
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001439 @staticmethod
1440 def tvgReduceProduct(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001441 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001442 ):
1443 dtype = dtypeList[0]
1444 if error_name is None:
1445            # Limit ranges for (non error) tests by using values that can be
1446            # multiplied along any axis without hitting infinity
1447 highval_lookup = {
1448 dtype: math.pow(
1449 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
1450 1 / max(shapeList[0]),
1451 )
1452 }
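            # i.e. raising highval to the power of max(shapeList[0]) stays
            # within the float high value for this dtype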
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001453 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001454 assert data_range is not None
1455 argsDict["data_range"] = data_range
1456
1457 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001458 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001459 )
1460
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001461 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001462 def tvgResize(
1463 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1464 ):
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001465 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001466 rng,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001467 dtypeList[0],
1468 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1469 )
1470 if data_range:
1471 argsDict["data_range"] = data_range
1472 # Needed for compliance
1473 argsDict["max_abs_value"] = data_range[1]
1474
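        # Supply the resize scale, offset and border attributes as fixed
        # shape tensor data alongside the input tensor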
Tai Lyc5c2a7e2024-02-22 23:26:28 +00001475 scale_values = argsDict["scale"]
1476 offset_values = argsDict["offset"]
1477 border_values = argsDict["border"]
1478 dtypeList[1] = DType.SHAPE
1479 dtypeList[2] = DType.SHAPE
1480 dtypeList[3] = DType.SHAPE
1481 shapeList[1] = [len(scale_values)]
1482 shapeList[2] = [len(offset_values)]
1483 shapeList[3] = [len(border_values)]
1484 argsDict["fixed_data"] = [None, scale_values, offset_values, border_values]
1485
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001486 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001487 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001488 )
1489
Jeremy Johnson30476252023-11-20 16:15:30 +00001490 # Set the POW exponent high data range
1491 TVG_FLOAT_HIGH_VALUE_POW_EXP = {
1492 DType.FP32: 10.0,
1493 DType.FP16: 10.0,
1494 DType.BF16: 10.0,
1495 }
1496    # Highest POW base value (with a safe margin of error) that can be raised
1497    # to a positive exponent without becoming infinity
1498 TVG_FLOAT_HIGH_VALUE_POW_BASE = {
1499 DType.FP32: math.floor(
1500 math.pow(
1501 TVG_FLOAT_HIGH_VALUE[DType.FP32],
1502 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1503 )
1504 ),
1505 DType.FP16: math.floor(
1506 math.pow(
1507 TVG_FLOAT_HIGH_VALUE[DType.FP16],
1508 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1509 )
1510 ),
1511 DType.BF16: math.floor(
1512 math.pow(
1513 TVG_FLOAT_HIGH_VALUE[DType.BF16],
1514 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1515 )
1516 ),
1517 }
1518    # Lowest POW base value (with a safe margin of error) that can be raised
1519    # to a negative exponent without becoming infinity
1520 TVG_FLOAT_LOW_VALUE_POW_BASE = {
1521 DType.FP32: math.ceil(
1522 math.pow(
1523 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
1524 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1525 )
1526 * 1000
1527 )
1528 / 1000,
1529 DType.FP16: math.ceil(
1530 math.pow(
1531 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
1532 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1533 )
1534 * 1000
1535 )
1536 / 1000,
1537 DType.BF16: math.ceil(
1538 math.pow(
1539 1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
1540 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1541 )
1542 * 1000
1543 )
1544 / 1000,
1545 }
1546
1547 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001548 def tvgPow(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001549 if error_name is not None:
1550 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001551 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001552 )
1553 dtype = dtypeList[0]
1554 # Different ranges for POW
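        # Test set 0: positive base with fractional exponent
        # Test set 1: positive base with integer exponent
        # Test set 2: negative base with integer exponent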
1555 test_set = argsDict["s"]
1556 if test_set == 0:
1557 # Positive base with fractional exponent
1558 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001559 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001560 dtype,
1561 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1562 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1563 )
1564 exp_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001565 rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
Jeremy Johnson30476252023-11-20 16:15:30 +00001566 )
1567 exp_round = False
1568 else:
1569 # Integer exponent
1570 exp_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001571 rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
Jeremy Johnson30476252023-11-20 16:15:30 +00001572 )
1573 exp_round = True
1574 if test_set == 1:
1575 # Positive base
1576 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001577 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001578 dtype,
1579 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1580 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1581 )
1582 else:
1583 assert test_set == 2
1584 # Negative base
1585 # Supply new look up tables with negative values
1586 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001587 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001588 dtype,
1589 {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
1590 {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
1591 )
1592
1593 data_range_list = (
1594 {
1595 "range": base_range,
1596 },
1597 {
1598 "range": exp_range,
1599 "round": exp_round,
1600 },
1601 )
1602 argsDict["data_range_list"] = data_range_list
1603 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001604 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001605 )
1606
1607 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001608 def tvgLogRsqrt(
1609 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1610 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001611        # LOG & RSQRT data range runs from the lowest expressible positive
1612        # number up to the largest to avoid NaNs
1613 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001614 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001615 dtypeList[0],
1616 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1617 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1618 )
1619 if data_range:
1620 argsDict["data_range"] = data_range
1621
1622 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001623 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001624 )
1625
1626    # Set the EXP data range to the logs of the largest and smallest values
1627    # to avoid infinities or a result that is flushed to zero
1628 TVG_FLOAT_HIGH_VALUE_EXP = {
1629 DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1630 DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1631 DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1632 }
1633 TVG_FLOAT_LOW_VALUE_EXP = {
1634 DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
1635 DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
1636 DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
1637 }
1638
1639 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001640 def tvgExp(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001641 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001642 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001643 dtypeList[0],
1644 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1645 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1646 )
1647 if data_range:
1648 argsDict["data_range"] = data_range
1649
1650 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001651 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001652 )
1653
1654 @staticmethod
1655 def tvgFullyConnected(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001656 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson30476252023-11-20 16:15:30 +00001657 ):
1658 dtype = dtypeList[0]
1659 if (
1660 error_name is None
1661 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
Jeremy Johnson718f3472023-11-30 14:18:19 +00001662 and dtype in (DType.BF16,)
Jeremy Johnson30476252023-11-20 16:15:30 +00001663 ):
Jeremy Johnson718f3472023-11-30 14:18:19 +00001664 # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
Jeremy Johnson30476252023-11-20 16:15:30 +00001665            # Limit ranges for (non error & non compliance) FP tests by using
1666            # values that can be multiplied along any axis without hitting infinity/NaN
1667 IC = shapeList[0][1]
1668 highval_lookup = {
1669 dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
1670 }
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001671 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnson30476252023-11-20 16:15:30 +00001672 assert data_range is not None
1673 argsDict["data_range"] = data_range
1674
1675 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001676 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001677 )
1678
Jeremy Johnson708da822023-11-15 16:25:45 +00001679 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001680 def tvgCast(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson708da822023-11-15 16:25:45 +00001681 in_dtype = dtypeList[0]
1682 out_dtype = argsDict["out_type"]
1683        # Create a look-up that limits the input tensor to the output type's
1684        # maximum to avoid FP infinities and integer saturation
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001685 out_range = rng.dTypeRange(out_dtype, high_inclusive=True)
Jeremy Johnson708da822023-11-15 16:25:45 +00001686 highval_lookup = {in_dtype: out_range[1]}
1687 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001688 rng,
Jeremy Johnson708da822023-11-15 16:25:45 +00001689 in_dtype,
1690 highval_lookup,
1691 )
1692
1693 assert data_range is not None
1694 argsDict["data_range"] = data_range
1695
1696 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001697 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson708da822023-11-15 16:25:45 +00001698 )
1699
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001700 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001701 def tvgGather(
1702 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1703 ):
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001704 K = shapeList[0][1]
1705
1706 # Fix the type of the indices tensor
1707 dtypeList[1] = DType.INT32
1708
1709 dtype = dtypeList[0]
1710 if not gtu.dtypeIsSupportedByCompliance(dtype):
1711 # Test unsupported by data generator
1712 op = testGen.TOSA_OP_LIST[opName]
1713 pCount, cCount = op["operands"]
1714 assert (
1715 pCount == 2 and cCount == 0
1716 ), "Op.GATHER must have 2 placeholders, 0 consts"
1717
1718 tens_ser_list = []
1719 for idx, shape in enumerate(shapeList):
1720 dtype = dtypeList[idx]
1721 if idx != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001722 arr = rng.randTensor(shape, dtype)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001723 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1724 else:
1725                    # Limit the data range of the indices tensor to K (exclusive)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001726 arr = rng.randTensor(shape, dtype, (0, K))
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001727 # To match old functionality - create indices as CONST
1728 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1729
1730 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1731
1732 else:
1733 # ERROR_IF or floating point test
1734            # Use inclusive values up to index K-1 for the indices tensor
1735 data_range_list = (
1736 {"range": None},
1737 {"range": (0, K - 1)},
1738 )
1739 argsDict["data_range_list"] = data_range_list
1740
1741 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001742 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001743 )
1744
1745 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001746 def tvgScatter(
1747 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1748 ):
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001749 K = shapeList[0][1]
1750 W = shapeList[2][1]
1751
1752        # Work out an indices tensor here with data that doesn't exceed the
1753        # dimension K of the values_in tensor and, as required by the spec,
1754        # does NOT repeat the same K location:
1755 # "It is not permitted to repeat the same output index within a single
1756 # SCATTER operation and so each output index occurs at most once."
1757        assert K >= W, "Op.SCATTER W must be smaller than or equal to K"
1758
1759 # Fix the type of the indices tensor
1760 dtypeList[1] = DType.INT32
1761
1762 dtype = dtypeList[0]
1763 if not gtu.dtypeIsSupportedByCompliance(dtype):
1764 # Test unsupported by data generator
1765 op = testGen.TOSA_OP_LIST[opName]
1766 pCount, cCount = op["operands"]
1767 assert (
1768 pCount == 3 and cCount == 0
1769 ), "Op.SCATTER must have 3 placeholders, 0 consts"
1770
1771 tens_ser_list = []
1772 for idx, shape in enumerate(shapeList):
1773 dtype = dtypeList[idx]
1774 if idx != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001775 arr = rng.randTensor(shape, dtype)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001776 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1777 else:
1778 # Create the indices array
1779 assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
1780 arr = []
1781 for n in range(shape[0]):
1782 # Get a shuffled list of output indices (0 to K-1) and
1783 # limit length to W
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001784 arr.append(rng.permutation(K)[:W])
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001785 indices_arr = np.array(arr, dtype=np.int32) # (N, W)
1786 # To match old functionality - create indices as CONST
1787 tens_ser_list.append(
1788 testGen.ser.addConst(shape, dtype, indices_arr)
1789 )
1790
1791 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1792
1793 else:
1794 # ERROR_IF or floating point test
1795            # Use inclusive values up to index K-1 for the indices tensor
1796 data_range_list = (
1797 {"range": None},
1798 {"range": (0, K - 1)},
1799 {"range": None},
1800 )
1801 argsDict["data_range_list"] = data_range_list
1802
1803 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001804 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001805 )
1806
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001807
1808class TosaArgGen:
1809 """Argument generators create exhaustive or random lists of attributes for
1810 operators that take attributes or other parameters.
1811
1812    The return value is a list of (descriptive_name, args_dict) tuples where
1813    the descriptive_name is appended to the test name and the args_dict is
1814    passed to the operator build function.
1815 """
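    # Illustrative example (not tied to a specific operator): an axis argument
    # generator may return [("axis0", {"axis": 0}), ("axis1", {"axis": 1})],
    # producing test names suffixed with "axis0" and "axis1"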
1816
1817 def __init__(self):
1818 pass
1819
1820 @staticmethod
evacha019c96eef2024-02-07 11:21:55 +00001821 def _add_data_generators(testGen, opName, shapeList, dtype, arg_list, error_name):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001822 """Add extra tests for each type of data generator for this op."""
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001823 if (
1824 error_name is None
1825 and "data_gen" in testGen.TOSA_OP_LIST[opName]
1826 and gtu.dtypeIsSupportedByCompliance(dtype)
1827 ):
Tai Ly60dc48c2024-03-08 22:19:41 +00001828 if gtu.dtypeIsFloat(dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001829 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
1830 else:
1831 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
1832 else:
1833 # Error test or No data generator types listed - assume random
1834 dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
1835
1836 # Expand arg list with other data generator types
1837 new_arg_list = []
1838 for dg_type in dataGenTypesList:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001839 for arg_str, args_dict in arg_list:
evacha019c96eef2024-02-07 11:21:55 +00001840
1841 if dg_type == gtu.DataGenType.FULL_RANGE:
1842 tensor_size = gtu.product(shapeList[0])
1843 if tensor_size >= gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]:
1844 # Large enough tensor data size for full range, add a single test
1845 num_test_sets = 0
1846 else:
1847 # Not enough data size for full range of values, revert to random numbers
1848 dg_type = gtu.DataGenType.PSEUDO_RANDOM
1849
Jeremy Johnson1271c442023-09-05 11:39:26 +01001850 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
Jeremy Johnson30476252023-11-20 16:15:30 +00001851 if error_name is None:
1852 num_test_sets = (
1853 args_dict["num_test_sets"]
1854 if "num_test_sets" in args_dict
1855 else 0
1856 )
1857 else:
evacha019c96eef2024-02-07 11:21:55 +00001858 # Add single test for pseudo random
Jeremy Johnson30476252023-11-20 16:15:30 +00001859 num_test_sets = 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01001860
1861 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
1862 # Extra tests for each dot product test set
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001863 dot_products = args_dict["dot_products"]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001864 if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001865 shape_info = (
1866 " ({})".format(testGen.shapeStr(args_dict["shape"]))
1867 if "shape" in args_dict
1868 else ""
1869 )
Jeremy Johnsonaf090182024-02-13 18:25:39 +00001870 logger.info(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001871 f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
Jeremy Johnson1271c442023-09-05 11:39:26 +01001872 )
1873 continue
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001874 # KS and acc_type is required by all dot product generators
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001875 assert "ks" in args_dict
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001876 assert "acc_type" in args_dict
Jeremy Johnson1271c442023-09-05 11:39:26 +01001877
Jeremy Johnson30476252023-11-20 16:15:30 +00001878 num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS
1879
1880 if num_test_sets > 0:
1881 for s in range(0, num_test_sets):
evacha019c96eef2024-02-07 11:21:55 +00001882 set_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
1883 set_args_dict = args_dict.copy()
1884 set_args_dict["s"] = s
1885 set_args_dict["dg_type"] = dg_type
1886 new_arg_list.append((set_arg_str, set_args_dict))
Jeremy Johnson30476252023-11-20 16:15:30 +00001887 else:
1888 # Default is a single test
evacha019c96eef2024-02-07 11:21:55 +00001889 new_args_dict = args_dict.copy()
1890 new_args_dict["dg_type"] = dg_type
1891 new_arg_list.append((arg_str, new_args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001892
1893 return new_arg_list
1894
1895 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001896 def agNone(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001897 """A trivial argument generator for operators that don't take any
1898 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001899 arg_list = TosaArgGen._add_data_generators(
1900 testGen,
1901 opName,
evacha019c96eef2024-02-07 11:21:55 +00001902 shapeList,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001903 dtype,
1904 [("", {})],
1905 error_name,
1906 )
1907 # Return list of tuples: (arg_str, args_dict)
1908 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001909
1910 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001911 def agPow(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001912 """Pow operator needs different test sets to cover random numbers
1913 without creating NaNs or Infs"""
1914 arg_list = TosaArgGen._add_data_generators(
1915 testGen,
1916 opName,
evacha019c96eef2024-02-07 11:21:55 +00001917 shapeList,
Jeremy Johnson30476252023-11-20 16:15:30 +00001918 dtype,
1919 [("", {"num_test_sets": 3})],
1920 error_name,
1921 )
1922 # Return list of tuples: (arg_str, args_dict)
1923 return arg_list
1924
1925 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001926 def agAxis(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001927 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001928 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001929 shape = shapeList[0]
1930
1931 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001932 # Set too small axis
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001933 axes = [rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001934 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001935 # Set too large axis
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001936 axes = [rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001937 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001938 # Create tests for each dimension
1939 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001940
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001941 opid = testGen.TOSA_OP_LIST[opName]["op"]
1942
1943 for a in axes:
1944 args_dict = {"axis": int(a)}
1945 if opid == Op.REDUCE_SUM:
Jeremy Johnsone52c0a32024-03-11 09:58:24 +00001946 output_shape = shape.copy()
1947 if error_name is None:
1948                    # Only non error_if tests need the dot_products calculated
1949                    # correctly, as error_if tests should never be run
1950 output_shape[a] = 1
1951 args_dict["dot_products"] = gtu.product(output_shape)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001952 args_dict["shape"] = shape
1953 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1954 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1955
1956 arg_list.append(("axis{}".format(a), args_dict))
1957
1958 arg_list = TosaArgGen._add_data_generators(
1959 testGen,
1960 opName,
evacha019c96eef2024-02-07 11:21:55 +00001961 shapeList,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001962 dtype,
1963 arg_list,
1964 error_name,
1965 )
1966 # Return list of tuples: (arg_str, args_dict)
1967 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001968
1969 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001970 def _calculate_sparsity(num_tests, sparsity_factor):
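        # Worked example: 5000 combinations with a sparsity_factor of 120
        # gives 5000 // 120 + 1 = 42, bumped to 43 to avoid multiples of
        # 2, 3 and 5, so roughly every 43rd combination is selected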
1971 sparsity = num_tests // sparsity_factor + 1
1972 # If there are only a small number of tests, just select them all
1973 if sparsity < 13:
1974 sparsity = 1
1975 # To get a variety of parameter combinations sparsity should not be a
1976 # multiple of 2, 3 or 5
1977 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1978 sparsity += 1
1979 return sparsity
1980
Jeremy Johnsondd975b82024-02-28 17:29:13 +00001981 # Maximum number of error_if variants to produce
Jeremy Johnson87460262024-03-25 09:46:02 +00001982 MAX_TESTS_ERROR_IFS = 3
Jeremy Johnsondd975b82024-02-28 17:29:13 +00001983
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001984 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001985 def agConv(testGen, rng, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson0c716862023-04-13 17:18:19 +01001986 # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001987 arg_list = []
1988
Jeremy Johnson0c716862023-04-13 17:18:19 +01001989 if testGen.args.level8k and error_name is not None:
1990 # Don't produce negative large tests
1991 return arg_list
1992
1993 # Shape: Batches, (Depth), Height, Width, Channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001994 ifm_shape = shapeList[0]
Jeremy Johnson0c716862023-04-13 17:18:19 +01001995 # Shape: (OFM channels), (KD), KH, KW, IFM channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001996 filter_shape = shapeList[1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001997
Tai Lyf36f2562024-03-14 16:21:29 +00001998 accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
1999
2000 if error_name == ErrorIf.WrongAccumulatorType:
2001 accum_dtypes = (
2002 [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
2003 )
James Ward8b390432022-08-12 20:48:56 +01002004
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002005 # Op type checks
Jeremy Johnson0c716862023-04-13 17:18:19 +01002006 conv3d = opName.startswith("conv3d")
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002007 depthwise = opName.startswith("depthwise")
2008
2009 # Check the rank
Jeremy Johnson0c716862023-04-13 17:18:19 +01002010 rank = 5 if conv3d else 4
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002011 if error_name != ErrorIf.WrongRank:
2012 assert len(ifm_shape) == rank
2013 assert len(filter_shape) == rank
2014
Jeremy Johnson0c716862023-04-13 17:18:19 +01002015 # kernel rank omits channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002016 k_rank = rank - 2
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002017 k_pos = 0 if depthwise else 1
Jeremy Johnson0c716862023-04-13 17:18:19 +01002018 k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002019 # compliance size - KS
2020 k_size = gtu.product(k_shape)
2021 if not depthwise:
2022 k_size *= ifm_shape[-1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002023
Jeremy Johnson0c716862023-04-13 17:18:19 +01002024 if not testGen.args.level8k:
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002025 if error_name in (
2026 ErrorIf.PadSmallerZero,
2027 ErrorIf.StrideSmallerOne,
2028 ErrorIf.DilationSmallerOne,
2029 ):
2030 # Use specific invalid value(s)
2031 if error_name == ErrorIf.PadSmallerZero:
2032 # Create negative paddings but with positive opposite paddings
2033 neg_pad = rng.choice(range(-5, 0))
2034 p_vals = [neg_pad, abs(neg_pad)]
2035 else:
2036 p_vals = [0, 0]
2037 if error_name == ErrorIf.StrideSmallerOne:
2038                    # Can't use stride=0, as it is used as a divisor when deriving the output shape
2039 s_vals = [rng.choice(range(-5, 0))]
2040 else:
2041 s_vals = [1]
2042 if error_name == ErrorIf.DilationSmallerOne:
2043 d_vals = [rng.choice(range(-5, 1))]
2044 else:
2045 d_vals = [1]
2046 p = p_vals * k_rank
2047 s = s_vals * k_rank
2048 d = d_vals * k_rank
2049
2050 # Fix values to produce valid error_if
2051 for index in range(k_rank):
2052 pad_offset = index * 2
2053 fixed = False
2054 while not fixed:
2055 partial = (
2056 ifm_shape[index + 1]
2057 - 1
2058 + p[pad_offset]
2059 + p[pad_offset + 1]
2060 - (k_shape[index] - 1) * d[index]
2061 )
2062 remainder = partial % s[index]
2063 if partial <= 0:
2064 p[pad_offset + 1] += abs(partial) + 1
2065 elif remainder:
2066 # Stride will be negative for StrideSmallerOne
2067 assert remainder < 0
2068 p[pad_offset + 1] += abs(remainder)
2069 else:
2070 fixed = True
2071 paddings = {tuple(p)}
2072 strides = {tuple(s)}
2073 dilations = {tuple(d)}
2074 logger.debug(f"agConv: error pad={p} stride={s} dilation={d}")
Jeremy Johnson0c716862023-04-13 17:18:19 +01002075 else:
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002076 # Generate comprehensive argument lists
Jeremy Johnson0c716862023-04-13 17:18:19 +01002077 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002078 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
Jeremy Johnson0c716862023-04-13 17:18:19 +01002079 # Stride must be greater than 1 to force non-integer error
2080 startStride = (
2081 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002082 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002083 s_vals = [
2084 x for x in range(startStride, testGen.args.max_conv_stride + 1)
2085 ]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002086 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002087
2088 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
2089 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002090
Jeremy Johnson0c716862023-04-13 17:18:19 +01002091 if not error_name and testGen.args.oversize:
2092 # add some oversize argument values
2093 if max(ifm_shape) < 64:
2094 bigPadding = 9
2095 paddings.update(
2096 {
2097 x
2098 for x in itertools.product(
2099 *([[0, bigPadding]] * (k_rank * 2))
2100 )
2101 }
2102 )
2103 bigStride = 8
2104 strides.update(
2105 {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
2106 )
2107 bigDilation = 7
2108 dilations.update(
2109 {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
2110 )
2111 max_dim_size = None
2112
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002113 if error_name:
2114 # Cycle through all error_if tests but we only keep the first few
2115 sparsity = 1
2116 else:
2117                # There are too many parameter combinations, so generate them sparsely
2118 sparsity_factor = 120
2119 sparsity = TosaArgGen._calculate_sparsity(
2120 len(paddings) * len(strides) * len(dilations), sparsity_factor
2121 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002122 else:
2123 # Only test 8k levels boundaries
2124 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2125 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2126 bigPadding = bigKernel
2127
2128 dilation_shape = [1] * k_rank
2129 pad_shape = [0] * k_rank * 2
2130 if conv3d:
2131                # Small stride, except for the big kernel (see below), to keep
2132                # the tensor size/calculation small
2133 stride_shape = [1] * k_rank
2134 for idx in range(k_rank):
2135 pad_offset = idx * 2
2136 if k_shape[idx] == bigKernel:
2137 # Padding shape needs to account for tensor shape
2138 pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
2139 pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
2140 # Big stride to reduce output size
2141 stride_shape[idx] = bigKernel
2142 else:
2143 # Account for kernel size
2144 pad_shape[pad_offset] = k_shape[idx] - 1
2145 else:
2146 # Always have a large stride with extra padding and dilation to keep
2147 # tensor calculation reasonable
2148 stride_shape = [bigKernel] * k_rank
2149 for idx in range(k_rank):
2150 # Dilation shape must account for kernel size
2151 dilation_shape[idx] = bigKernel // k_shape[idx]
2152 # Padding shape needs to accommodate tensor/kernel & dilation
2153 pad_offset = idx * 2
2154 pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
2155 pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
2156
2157 strides = {tuple(stride_shape)}
2158 dilations = {tuple(dilation_shape)}
2159 paddings = {tuple(pad_shape)}
2160 # Create a limit for the output dimensions size
2161 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2162
2163 # Currently allow all combinations that are reasonable size
2164 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002165
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002166 more_tests = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002167 n = 0
Tai Lyf36f2562024-03-14 16:21:29 +00002168 for a in accum_dtypes:
2169 for s in sorted(list(strides)):
2170 for p in sorted(list(paddings)):
2171 for d in sorted(list(dilations)):
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002172 if (
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002173 more_tests
2174 and n % sparsity == 0
Tai Lyf36f2562024-03-14 16:21:29 +00002175                            # the padded shape must exceed the dilation * kernel to get a
2176                            # positive-sized output shape
2177 and (ifm_shape[1] - 1 + p[0] + p[1])
2178 > d[0] * (k_shape[0] - 1)
2179 and (ifm_shape[2] - 1 + p[2] + p[3])
2180 > d[1] * (k_shape[1] - 1)
2181 and (
2182 k_rank < 3
2183 or (
2184 (ifm_shape[3] - 1 + p[4] + p[5])
2185 > d[2] * (k_shape[2] - 1)
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002186 )
2187 )
Tai Lyf36f2562024-03-14 16:21:29 +00002188 ):
2189 remainders = []
2190 outputs = []
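                            # Standard convolution output size per dimension:
                            # O = ((I - 1 + pad_start + pad_end - (K - 1) * dilation) // stride) + 1
                            # where a zero remainder means an integer-exact output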
2191 for index in range(k_rank):
2192 pad_offset = index * 2
2193 partial = (
2194 ifm_shape[index + 1]
2195 - 1
2196 + p[pad_offset]
2197 + p[pad_offset + 1]
2198 - (k_shape[index] - 1) * d[index]
2199 )
2200 remainders.append(partial % s[index])
2201 outputs.append((partial // s[index]) + 1)
2202
2203 if (
2204                                # the parameters must produce an integer-exact output
2205 error_name != ErrorIf.ConvOutputShapeNonInteger
2206 and max(remainders) == 0
2207 ) or (
2208 error_name == ErrorIf.ConvOutputShapeNonInteger
2209 and max(remainders) > 0
2210 ):
2211 if (
2212 max_dim_size is not None
2213 and max(outputs) >= max_dim_size
2214 ):
2215 # Test will consume too much memory - skip it
2216 continue
2217
2218 # Compliance - number of dot product calculations
2219 if depthwise:
2220 # N*OH*OW*C*M
2221 dots = gtu.product(
2222 (ifm_shape[0], *outputs, *filter_shape[2:])
2223 )
2224 else:
2225 # N*OH*OW*OC or N*OD*OH*OW*OC
2226 dots = gtu.product(
2227 (ifm_shape[0], *outputs, filter_shape[0])
2228 )
2229 args_dict = {
2230 "acc_type": a,
2231 "stride": s,
2232 "pad": p,
2233 "dilation": d,
2234 "kernel": k_shape,
2235 "ks": k_size,
2236 "dot_products": dots,
2237 "shape": ifm_shape,
2238 }
2239
2240                                # Values larger than 9 need a different delimiter
2241 delim = "" if max(s + p + d) <= 9 else "x"
2242 arg_list.append(
2243 (
2244 "acc{}_st{}_pad{}_dilat{}".format(
2245 testGen.typeStr(a),
2246 delim.join([str(x) for x in s]),
2247 delim.join([str(x) for x in p]),
2248 delim.join([str(x) for x in d]),
2249 ),
2250 args_dict,
2251 )
2252 )
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002253 if (
2254 error_name
Jeremy Johnson87460262024-03-25 09:46:02 +00002255 and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002256 ):
2257 # Found enough errors
2258 logger.debug(
2259 f"Skipping creating more conv error tests for {error_name}"
2260 )
2261 more_tests = False
Tai Lyf36f2562024-03-14 16:21:29 +00002262 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002263
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002264 arg_list = TosaArgGen._add_data_generators(
2265 testGen,
2266 opName,
evacha019c96eef2024-02-07 11:21:55 +00002267 shapeList,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002268 dtypes[0],
2269 arg_list,
2270 error_name,
2271 )
2272 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002273 return arg_list
2274
2275 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002276 def agFullyConnected(testGen, rng, opName, shapeList, dtypes, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002277
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002278 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002279 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01002280
2281 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002282 accum_dtype = gtu.get_wrong_output_type(opName, rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01002283 elif error_name == ErrorIf.WrongInputType:
2284 # Pick some potentially correct output dtype if input type is incorrect
2285 accum_dtype = DType.INT32
2286 else:
Tai Lyf36f2562024-03-14 16:21:29 +00002287 accum_dtype = dtypes[-1] # use output dtype as accum_dtype
James Ward8b390432022-08-12 20:48:56 +01002288
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002289 # Set up compliance info
2290 args_dict = {
2291 "acc_type": accum_dtype,
2292 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
2293 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
2294 "shape": shapeList[0],
2295 }
2296
2297 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
2298
2299 arg_list = TosaArgGen._add_data_generators(
2300 testGen,
2301 opName,
evacha019c96eef2024-02-07 11:21:55 +00002302 shapeList,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002303 input_dtype,
2304 arg_list,
2305 error_name,
2306 )
2307 # Return list of tuples: (arg_str, args_dict)
2308 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002309
2310 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002311 def agMatMul(testGen, rng, opName, shapeList, dtype, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002312 # Get valid accumulate type(s)
2313 if dtype == DType.INT8:
2314 accum_dtypes = [DType.INT32]
2315 elif dtype == DType.INT16:
2316 accum_dtypes = [DType.INT48]
2317 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002318 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002319 elif dtype == DType.BF16:
2320 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002321 elif dtype == DType.FP32:
2322 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002323 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2324 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002325 elif error_name is None:
2326 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
2327
2328 if error_name == ErrorIf.WrongOutputType:
2329 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002330 accum_dtypes = [gtu.get_wrong_output_type(opName, rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01002331 elif error_name == ErrorIf.WrongInputType:
2332 # Pick some potentially correct output dtype if input type is incorrect
2333 accum_dtypes = [DType.INT32]
2334
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002335 # Set up compliance info
2336 args_dict = {
2337 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
2338 # Set dot_products = N*H*W
2339 "dot_products": gtu.product(
2340 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
2341 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002342 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002343 }
2344
2345 # Create arg tuple of string and dict
2346 arg_list = []
2347 for a in accum_dtypes:
2348 d = args_dict.copy()
2349 d["acc_type"] = a
2350 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01002351
2352 arg_list = TosaArgGen._add_data_generators(
2353 testGen,
2354 opName,
evacha019c96eef2024-02-07 11:21:55 +00002355 shapeList,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002356 dtype,
2357 arg_list,
2358 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002359 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002360 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01002361 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002362
2363 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002364 def agTransposeConv2D(testGen, rng, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002365 arg_list = []
2366
Jeremy Johnson0c716862023-04-13 17:18:19 +01002367 if testGen.args.level8k and error_name is not None:
2368 # Don't produce negative large tests
2369 return arg_list
2370
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002371 ifm_shape = shapeList[0]
2372 filter_shape = shapeList[1]
2373
Tai Lyf36f2562024-03-14 16:21:29 +00002374 accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
2375
2376 if error_name == ErrorIf.WrongAccumulatorType:
2377 accum_dtypes = (
2378 [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
2379 )
James Ward8b390432022-08-12 20:48:56 +01002380
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002381 # Must be rank 4
2382 if error_name != ErrorIf.WrongRank:
2383 assert len(ifm_shape) == 4
2384 assert len(filter_shape) == 4
2385
Jeremy Johnson0c716862023-04-13 17:18:19 +01002386 k_shape = tuple(filter_shape[1:3])
Jeremy Johnson95a67102024-01-10 14:16:39 +00002387 # compliance size - KS
2388 k_size = gtu.product((*k_shape, ifm_shape[3]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002389
Jeremy Johnson0c716862023-04-13 17:18:19 +01002390 if not testGen.args.level8k:
2391 # Generate comprehensive argument lists
2392 # - except for named errors, which use specific invalid value(s)
2393 smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
2394 if error_name == ErrorIf.PadLargerEqualKernel:
2395 max_filter_size = -max(k_shape[0], k_shape[1])
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002396 p_vals = [rng.choice(range(max_filter_size - 10, max_filter_size))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002397 else:
2398 p_vals = [
2399 x
2400 for x in range(
2401 smallest_padding_size, testGen.args.max_conv_padding + 1
2402 )
2403 ]
2404 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2405 if error_name == ErrorIf.StrideSmallerOne:
2406                # Can't use stride=0, as it is used as a divisor when deriving the output shape
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002407 s_vals = [rng.choice(range(-5, 0))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002408 else:
2409 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
2410 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002411
Jeremy Johnson0c716862023-04-13 17:18:19 +01002412 if not error_name and testGen.args.oversize:
2413 # add some oversize argument values
2414 if max(ifm_shape) < 64:
2415 bigPadding = 9
2416 paddings.update(
2417 {
2418 x
2419 for x in itertools.product(
2420 *([[smallest_padding_size, bigPadding]] * 4)
2421 )
2422 }
2423 )
2424 bigStride = 8
2425 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
2426
2427            # There are too many parameter combinations, so generate them
2428            # sparsely, and very sparsely for negative tests
2429 sparsity_factor = 2 if error_name else 10
2430 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
2431 # If there are only a small number of tests, just select them all
2432 if sparsity < 13:
2433 sparsity = 1
2434 # To get a variety of parameter combinations sparsity should not be a
2435 # multiple of 2, 3 or 5
2436 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
2437 sparsity += 1
2438 else:
2439 # Only test 8k levels boundaries
2440 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2441 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2442 bigPadding = bigKernel
2443
2444 pad_shape = [0] * (len(k_shape) * 2)
2445 stride_shape = [1] * len(k_shape)
2446 # The point at which input dimension combined with the stride will
2447 # create large output sizes!
2448 LARGE_SIZE = 2
2449 for idx in range(len(k_shape)):
2450 pad_offset = idx * 2
2451 if k_shape[idx] == bigKernel:
2452 # Set large stride
2453 stride_shape[idx] = bigKernel
2454 # Use negative output padding to reduce shape size
2455 pad_shape[pad_offset] = -(bigPadding - 1)
2456 if ifm_shape[idx + 1] > LARGE_SIZE:
2457 pad_shape[pad_offset + 1] = -(bigPadding - 1)
2458 else:
2459 # The other dimension should be the bigKernel
2460 alt_idx = 1 - idx
2461 if (
2462 k_shape[alt_idx] == bigKernel
2463 and ifm_shape[alt_idx + 1] < LARGE_SIZE
2464 ):
2465 # As the input is small, the large stride won't
2466 # affect the output so we can add some padding
2467 pad_shape[pad_offset + 1] = bigPadding
2468
2469 strides = {tuple(stride_shape)}
2470 paddings = {tuple(pad_shape)}
2471
2472 # Currently allow all combinations that are of a reasonable size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002473 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002474
2475 n = 0
Tai Lyf36f2562024-03-14 16:21:29 +00002476 for a in accum_dtypes:
2477 for s in sorted(list(strides)):
2478 for p in sorted(list(paddings)):
2479 if n % sparsity == 0:
2480 # Determine the output shape
2481 oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
2482 ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
2483 os = [ifm_shape[0], oh, ow, filter_shape[0]]
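# Illustrative example (assumed sizes): ifm height 8, stride 2, pads (1, 1)
# and kernel 3 give oh = (8 - 1) * 2 + 1 + 1 + 3 = 19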
Jeremy Johnson0c716862023-04-13 17:18:19 +01002484
Tai Lyf36f2562024-03-14 16:21:29 +00002485 # N*OH*OW*OC
2486 dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
2487 args_dict = {
2488 "acc_type": a,
2489 "stride": s,
2490 "pad": p,
2491 "kernel": k_shape,
2492 "ks": k_size,
2493 "dot_products": dots,
2494 "shape": ifm_shape,
2495 "out_shape": os,
2496 }
Jeremy Johnson95a67102024-01-10 14:16:39 +00002497
Tai Lyf36f2562024-03-14 16:21:29 +00002498 # Values larger than 9 need a different delimiter in the test name
2499 delim = "" if max(s + p) <= 9 else "x"
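# e.g. (illustrative) stride (10, 1) with pad (0, 0, 0, 0) formats as
# "st10x1_pad0x0x0x0" rather than the ambiguous "st101_pad0000"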
2500 arg_list.append(
2501 (
2502 "acc{}_st{}_pad{}_os{}".format(
2503 testGen.typeStr(a),
2504 delim.join([str(x) for x in s]),
2505 delim.join([str(x) for x in p]),
2506 "x".join([str(x) for x in os]),
2507 ),
2508 args_dict,
2509 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002510 )
Tai Lyf36f2562024-03-14 16:21:29 +00002511 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002512
Jeremy Johnson95a67102024-01-10 14:16:39 +00002513 arg_list = TosaArgGen._add_data_generators(
2514 testGen,
2515 opName,
evacha019c96eef2024-02-07 11:21:55 +00002516 shapeList,
Jeremy Johnson95a67102024-01-10 14:16:39 +00002517 dtypes[0],
2518 arg_list,
2519 error_name,
2520 )
2521 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002522 return arg_list
2523
2524 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002525 def agPad(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002526 rank = len(shapeList[0])
2527
2528 # Exhaustively test combinations of padding on each side of each dimension
2529 # - the range of padding values is defined by pad_min and pad_max
2530 # - for padding >9, the name format needs to be more distinctive
2531 pad_min, pad_max = 0, 1
2532 pad_values = [x for x in range(pad_min, pad_max + 1)]
2533 if error_name == ErrorIf.PadSmallerZero:
2534 pad_values = [x for x in range(-2, 0)]
2535 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
2536 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
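# Illustrative count: with pad_values [0, 1] there are 4 (before, after) pairs
# per axis, so a rank 4 input yields 4**4 = 256 padding combinations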
2537
2538 if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002539 pad_const_int = rng.randNumberDType(dtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002540 pad_const_fp = 0
Tai Ly60dc48c2024-03-08 22:19:41 +00002541 elif gtu.dtypeIsFloat(dtype):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002542 pad_const_int = 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002543 pad_const_fp = rng.randNumberDType(dtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002544 else:
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002545 assert error_name == ErrorIf.WrongInputType
2546 pad_const_int = 0
2547 pad_const_fp = 0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002548
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002549 list_shape_pad_values = list(shape_pad_values)
2550 # If we are producing tests for rank 6 or greater use sparsity
2551 if len(list_shape_pad_values) > 1024:
2552 sparsity_factor = 2 if error_name else 120
2553 sparsity = TosaArgGen._calculate_sparsity(
2554 len(list_shape_pad_values), sparsity_factor
2555 )
2556 else:
2557 sparsity = 1
2558
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002559 # Build arg list
2560 arg_list = []
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002561 for n, paddings in enumerate(list_shape_pad_values):
James Ward8b390432022-08-12 20:48:56 +01002562 paddings = list(paddings)
2563 args_valid = True
2564
2565 if error_name == ErrorIf.PadSmallerZero:
2566 # Prevent negative output shapes while ensuring still testing for negative padding
2567 for i in range(rank):
2568 dim_after_padding = (
2569 paddings[i][0] + paddings[i][1] + shapeList[0][i]
2570 )
2571 if dim_after_padding < 1:
2572 paddings[i] = (0, 0)
2573 if all([p > -1 for p in paddings[i]]):
2574 args_valid = False
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002575 if args_valid and n % sparsity == 0:
James Ward8b390432022-08-12 20:48:56 +01002576 name = "pad"
2577 for r in range(rank):
2578 before, after = paddings[r]
2579 name = f"{name}{before}{after}"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002580 args_dict = {
2581 "pad": np.array(paddings),
2582 "pad_const_int": pad_const_int,
2583 "pad_const_fp": pad_const_fp,
2584 }
2585 arg_list.append((name, args_dict))
James Ward8b390432022-08-12 20:48:56 +01002586
2587 if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002588 logger.debug(
2589 f"agPad: No PadSmallerZero ErrorIf test created for input shape: {shapeList[0]}"
2590 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002591
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002592 arg_list = TosaArgGen._add_data_generators(
2593 testGen,
2594 opName,
evacha019c96eef2024-02-07 11:21:55 +00002595 shapeList,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002596 dtype,
2597 arg_list,
2598 error_name,
2599 )
2600
2601 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002602 return arg_list
2603
2604 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002605 def agPooling(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002606 arg_list = []
2607
2608 shape = shapeList[0]
2609 if error_name != ErrorIf.WrongRank:
2610 assert len(shape) == 4
2611
Jeremy Johnson0c716862023-04-13 17:18:19 +01002612 test_level8k = testGen.args.level8k and error_name is None
2613
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002614 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002615 startKernel = 2
2616 startPad = 0
2617 if not test_level8k:
2618 # Generate comprehensive argument lists
2619 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2620 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2621 # Stride must be greater than 1 to force non-integer error
2622 s_vals = [
2623 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2624 ]
2625 strides = {x for x in itertools.product(*([s_vals] * 2))}
2626 k_vals = [
2627 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2628 ]
2629 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2630 max_dim_size = None
2631 else:
2632 # Only test 8k levels
2633 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2634 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2635 strides = {(1, bigStride), (bigStride, 4)}
2636 kernels = {(1, bigKernel), (bigKernel, 3)}
2637 paddings = set()
2638 for s in sorted(list(strides)):
2639 for k in sorted(list(kernels)):
2640 padding = []
2641 for idx in range(len(k)):
2642 total_padding = s[idx] - shape[idx + 1] + k[idx]
2643 while total_padding < 0:
2644 # Must meet: shape + padding > kernel
2645 total_padding += s[idx]
2646 if total_padding < k[idx]:
2647 padding.extend([0, total_padding])
2648 else:
2649 # Note this may produce padding >= k[idx] which is not
2650 # allowed - but will be ignored in the creation loop below
2651 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2652 paddings.add(tuple(padding))
2653 # Create a limit for the output dimensions size
2654 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
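# Illustrative derivation (assumed values): stride 4, kernel 3 and an input
# dim of 10 give total_padding = 4 - 10 + 3 = -3; adding the stride repeatedly
# reaches 1, which is < kernel so the axis padding becomes [0, 1]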
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002655
James Ward8b390432022-08-12 20:48:56 +01002656 if opName == "max_pool2d":
2657 accum_dtypes = [None] # max_pool has no accumulate dtype
2658 elif dtype == DType.INT8 or dtype == DType.INT16:
2659 accum_dtypes = [DType.INT32]
2660 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002661 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002662 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002663 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002664 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2665 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002666 elif error_name is None:
2667 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2668 else:
2669 # Set to something for the ErrorIf case which has
2670 # incorrect input data-type
2671 accum_dtypes = [DType.INT32]
2672
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00002673 if error_name == ErrorIf.WrongAccumulatorType:
2674 accum_dtypes = list(gtu.usableDTypes(excludes=accum_dtypes))
2675
Jeremy Johnson0c716862023-04-13 17:18:19 +01002676 if not test_level8k:
2677 if testGen.args.oversize:
2678 # add some oversize argument values
2679 bigStride = 7
2680 bigKernel = 9
2681 strides.update(
2682 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002683 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002684 kernels.update(
2685 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2686 )
2687 if max(shape) < 64:
2688 # padding must be less than the kernel size
2689 bigPadding = bigKernel - 1
2690 paddings.update(
2691 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2692 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002693
Jeremy Johnson87460262024-03-25 09:46:02 +00002694 if error_name:
2695 # Cycle through all error_if tests, but only keep the first few
2696 sparsity = 1
2697 else:
2698 # There are too many parameter combinations, so generate them sparsely
2699 sparsity_factor = 500
2700 sparsity = (
2701 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2702 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002703 else:
2704 # We have already limited test output combinations for 8k tests
2705 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002706
James Ward8b390432022-08-12 20:48:56 +01002707 arg_str = (
2708 "acc{}_st{}_kern{}_pad{}"
2709 if accum_dtypes[0] is not None
2710 else "st{}_kern{}_pad{}"
2711 )
2712
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002713 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01002714 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002715 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01002716
2717 # Values larger than 9 need a different delimiter in the test name
2718 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01002719 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01002720 delim.join([str(x) for x in stride]),
2721 delim.join([str(x) for x in kern]),
2722 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01002723 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002724 args_dict = {
2725 "stride": stride,
2726 "pad": pad,
2727 "kernel": kern,
2728 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002729 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002730 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
2731 }
James Ward8b390432022-08-12 20:48:56 +01002732
2733 if accum is not None:
2734 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002735 args_dict["acc_type"] = accum
2736 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002737
Jeremy Johnson87460262024-03-25 09:46:02 +00002738 more_tests = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002739 n = 0
James Ward8b390432022-08-12 20:48:56 +01002740 for a in accum_dtypes:
2741 for s in sorted(list(strides)):
2742 for p in sorted(list(paddings)):
2743 for k in sorted(list(kernels)):
2744 if error_name in [
2745 ErrorIf.StrideSmallerOne,
2746 ErrorIf.KernelSmallerOne,
2747 ErrorIf.PadSmallerZero,
2748 ErrorIf.PadLargerEqualKernel,
2749 ]:
2750 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002751 rng, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002752 )
James Ward8b390432022-08-12 20:48:56 +01002753 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002754 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002755 get_arg_list_element(a, sNew, pNew, kNew, shape=shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002756 )
James Ward8b390432022-08-12 20:48:56 +01002757 elif (
Jeremy Johnson87460262024-03-25 09:46:02 +00002758 more_tests
2759 and n % sparsity == 0
James Ward8b390432022-08-12 20:48:56 +01002760 # padding must not exceed the kernel size
2761 and p[0] < k[0]
2762 and p[1] < k[0]
2763 and p[2] < k[1]
2764 and p[3] < k[1]
2765 # the padded shape must exceed the kernel size
2766 and (shape[1] + p[0] + p[1]) > k[0]
2767 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002768 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002769 partial_h = shape[1] + p[0] + p[1] - k[0]
2770 partial_w = shape[2] + p[2] + p[3] - k[1]
2771 remainder_h = partial_h % s[0]
2772 remainder_w = partial_w % s[1]
2773 output_h = partial_h // s[0] + 1
2774 output_w = partial_w // s[1] + 1
Jeremy Johnsonaf090182024-02-13 18:25:39 +00002775 logger.debug(
2776 f"agPooling: {shape} remainder=({remainder_h}, {remainder_w}) output=({output_h}, {output_w})"
2777 )
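# Illustrative example (assumed sizes): input height 32, pads (0, 0),
# kernel 2 and stride 2 give partial_h = 30, remainder_h = 0 and
# output_h = 30 // 2 + 1 = 16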
James Ward8b390432022-08-12 20:48:56 +01002778 if (
2779 # the parameters must produce integer exact output
2780 error_name != ErrorIf.PoolingOutputShapeNonInteger
2781 and remainder_h == 0
2782 and remainder_w == 0
2783 ) or (
2784 error_name == ErrorIf.PoolingOutputShapeNonInteger
2785 and (remainder_h != 0 or remainder_w != 0)
2786 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002787 if (
2788 max_dim_size is not None
2789 and max(output_h, output_w) > max_dim_size
2790 ):
2791 # Test will consume too much memory - skip it
2792 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002793 # Dot products = N*OH*OW*C
2794 dp = gtu.product(
2795 (shape[0], output_h, output_w, shape[3])
2796 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002797 arg_list.append(
2798 get_arg_list_element(a, s, p, k, dp, shape)
2799 )
Jeremy Johnson87460262024-03-25 09:46:02 +00002800 if (
2801 error_name
2802 and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS
2803 ):
2804 # Found enough errors
2805 logger.debug(
2806 f"Skipping creating more pooling error tests for {error_name}"
2807 )
2808 more_tests = False
2809
James Ward8b390432022-08-12 20:48:56 +01002810 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002811
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002812 # Now add data generator types
2813 arg_list = TosaArgGen._add_data_generators(
2814 testGen,
2815 opName,
evacha019c96eef2024-02-07 11:21:55 +00002816 shapeList,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002817 dtype,
2818 arg_list,
2819 error_name,
2820 )
2821
2822 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002823 return arg_list
2824
2825 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002826 def agCast(testGen, rng, opName, shapeList, inDtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002827 arg_list = []
2828
2829 # Enumerate the output types here
2830 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002831 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(inDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002832 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00002833 dtypeList = [
2834 DType.BOOL,
2835 DType.INT16,
2836 DType.INT32,
2837 DType.FP16,
2838 DType.BF16,
2839 DType.FP32,
2840 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002841 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00002842 dtypeList = [
2843 DType.BOOL,
2844 DType.INT8,
2845 DType.INT32,
2846 DType.FP16,
2847 DType.BF16,
2848 DType.FP32,
2849 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002850 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00002851 dtypeList = [
2852 DType.BOOL,
2853 DType.INT8,
2854 DType.INT16,
2855 DType.FP16,
2856 DType.BF16,
2857 DType.FP32,
2858 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002859 elif inDtype == DType.BOOL:
2860 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01002861 elif inDtype == DType.FP16:
Won Jeon2c34b462024-02-06 18:37:00 +00002862 dtypeList = [
2863 DType.INT8,
2864 DType.INT16,
2865 DType.INT32,
2866 DType.FP32,
2867 DType.FP8E4M3,
2868 DType.FP8E5M2,
2869 ]
James Ward24dbc422022-10-19 12:20:31 +01002870 elif inDtype == DType.BF16:
Won Jeon2c34b462024-02-06 18:37:00 +00002871 dtypeList = [
2872 DType.INT8,
2873 DType.INT16,
2874 DType.INT32,
2875 DType.FP32,
2876 DType.FP8E4M3,
2877 DType.FP8E5M2,
2878 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002879 elif inDtype == DType.FP32:
Won Jeon2c34b462024-02-06 18:37:00 +00002880 dtypeList = [
2881 DType.INT8,
2882 DType.INT16,
2883 DType.INT32,
2884 DType.FP16,
2885 DType.BF16,
2886 DType.FP8E4M3,
2887 DType.FP8E5M2,
2888 ]
2889 elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]:
2890 dtypeList = [DType.FP16, DType.BF16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002891 elif error_name == ErrorIf.WrongInputType:
2892 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002893 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002894 else:
2895 raise Exception("Unexpected input dtype: {}".format(inDtype))
2896
2897 for dtype in dtypeList:
Jeremy Johnson708da822023-11-15 16:25:45 +00002898 arg_list.append(
2899 ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
2900 )
2901
2902 # Now add data generator types
2903 arg_list = TosaArgGen._add_data_generators(
2904 testGen,
2905 opName,
evacha019c96eef2024-02-07 11:21:55 +00002906 shapeList,
Jeremy Johnson708da822023-11-15 16:25:45 +00002907 dtype,
2908 arg_list,
2909 error_name,
2910 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002911
2912 return arg_list
2913
2914 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002915 def agRescale(testGen, rng, opName, shapeList, inDtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002916 arg_list = []
2917
2918 # Enumerate the output types here
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002919 for outDtype in [
2920 DType.UINT8,
2921 DType.INT8,
2922 DType.INT16,
2923 DType.INT32,
2924 DType.UINT16,
2925 ]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002926 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002927 outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002928 and error_name == ErrorIf.OutputZeroPointNotZero
2929 ):
2930 continue
2931 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002932 outDtype != DType.UINT16
2933 and error_name == ErrorIf.U16OutputZeroPointNotValid
2934 ) or (
2935 inDtype != DType.UINT16
2936 and error_name == ErrorIf.U16InputZeroPointNotValid
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002937 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002938 # ErrorIfs only valid with UINT16
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002939 continue
2940 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002941 inDtype == DType.UINT8
2942 and outDtype not in [DType.INT8, DType.INT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002943 and error_name != ErrorIf.WrongOutputType
2944 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002945 # The only output dtypes for UINT8 are INT8/INT16, skip all others
2946 continue
2947 if (
2948 inDtype not in [DType.INT8, DType.INT16]
2949 and outDtype == DType.UINT8
2950 and error_name != ErrorIf.WrongOutputType
2951 ):
2952 # The only input dtypes for UINT8 are INT8/INT16, skip all others
2953 continue
2954 if (
2955 inDtype == DType.UINT16
2956 and outDtype != DType.INT16
2957 and error_name != ErrorIf.WrongOutputType
2958 ):
2959 # The only output dtype for UINT16 is INT16, skip all others
2960 continue
2961 if (
2962 inDtype != DType.INT16
2963 and outDtype == DType.UINT16
2964 and error_name != ErrorIf.WrongOutputType
2965 ):
2966 # The only input dtype for UINT16 is INT16, skip all others
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002967 continue
2968 if (
2969 error_name == ErrorIf.WrongOutputType
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002970 and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002971 ):
2972 continue
2973
2974 for scale32 in [False, True]:
2975 if error_name == ErrorIf.ScaleTrue and not scale32:
2976 continue
2977 elif error_name == ErrorIf.ScaleNotTrue and scale32:
2978 continue
2979 for double_round in [False, True]:
2980 if error_name == ErrorIf.ScaleNotTrue and not double_round:
2981 continue
2982 for per_channel in [False, True]:
2983
2984 if (
2985 inDtype == DType.INT48
2986 and scale32
2987 and error_name != ErrorIf.ScaleTrue
2988 ):
2989 # Illegal condition. Must be scale32=False
2990 continue
2991 if (
2992 double_round
2993 and not scale32
2994 and error_name != ErrorIf.ScaleNotTrue
2995 ):
2996 # Illegal condition. ERROR_IF(!scale32 && double_round)
2997 continue
2998
Tai Ly6e1e2bc2024-03-01 20:59:32 +00002999 if per_channel:
3000 nc = shapeList[0][-1]
3001 else:
3002 nc = 1
3003
3004 in_type_width = gtu.dtypeWidth(inDtype)
3005 out_type_width = gtu.dtypeWidth(outDtype)
3006
3007 # Calculate scale based on:
3008 # scale = a * (2^output_width) / (2^input_width)
3009
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003010 a = np.float32(rng.random(size=[nc]))
Tai Ly6e1e2bc2024-03-01 20:59:32 +00003011 scale_arr = a * np.float32(
3012 (1 << out_type_width) / (1 << in_type_width)
3013 )
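# Illustrative example (assumed widths): INT8 -> INT16 gives
# (1 << 16) / (1 << 8) = 256, so a random a of 0.5 produces a scale of
# about 128.0 before any clipping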
3014
3015 if scale32:
3016 # Cap the scaling at 2^31 - 1 for scale32
3017 scale_arr = np.clip(
3018 scale_arr, 1.0 / (1 << 31), (1 << 31) - 1
3019 )
3020 else:
3021 # Cap the scaling at 2^15 - 1 for scale16
3022 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
3023
Jeremy Johnsonaf090182024-02-13 18:25:39 +00003024 logger.debug(
3025 f"agRescale: {out_type_width} {in_type_width} -> {scale_arr}"
3026 )
Tai Ly6e1e2bc2024-03-01 20:59:32 +00003027
3028 multiplier_arr = np.int32(np.zeros(shape=[nc]))
3029 shift_arr = np.int32(np.zeros(shape=[nc]))
3030 for i in range(nc):
3031 (
3032 multiplier_arr[i],
3033 shift_arr[i],
3034 ) = TosaQuantGen.computeMultiplierAndShift(
3035 scale_arr[i], scale32
3036 )
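# Each scale is encoded as an integer (multiplier, shift) pair, roughly
# multiplier * 2^-shift (assuming the usual TOSA fixed-point convention);
# the exact encoding is handled by computeMultiplierAndShift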
3037
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003038 arg_list.append(
3039 (
3040 "out{}_sc{}_dr{}_pc{}".format(
Jeremy Johnson3b0544c2022-10-18 16:32:19 +01003041 testGen.typeStr(outDtype),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003042 int(scale32),
3043 int(double_round),
3044 int(per_channel),
3045 ),
Jeremy Johnson587cc842024-02-08 11:45:44 +00003046 {
3047 "output_dtype": outDtype,
3048 "scale": scale32,
3049 "double_round": double_round,
3050 "per_channel": per_channel,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00003051 "multiplier": multiplier_arr,
3052 "shift": shift_arr,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003053 },
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003054 )
3055 )
3056
Jeremy Johnson587cc842024-02-08 11:45:44 +00003057 arg_list = TosaArgGen._add_data_generators(
3058 testGen,
3059 opName,
evacha019c96eef2024-02-07 11:21:55 +00003060 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003061 inDtype,
3062 arg_list,
3063 error_name,
3064 )
3065 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003066 return arg_list
3067
3068 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003069 def agMul(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003070 arg_list = []
3071
3072 if dtype == DType.INT32:
3073 for p in range(testGen.args.num_rand_permutations):
3074
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003075 shift = rng.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003076 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003077 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003078 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003079
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003080 arg_list = TosaArgGen._add_data_generators(
3081 testGen,
3082 opName,
evacha019c96eef2024-02-07 11:21:55 +00003083 shapeList,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003084 dtype,
3085 arg_list,
3086 error_name,
3087 )
3088 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003089 return arg_list
3090
3091 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003092 def agArithmeticRightShift(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003093 arg_list = []
3094
Jeremy Johnson587cc842024-02-08 11:45:44 +00003095 for round in (True, False):
3096 args_dict = {
3097 "round": round,
3098 }
3099 arg_list.append((f"round{round}", args_dict))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003100
Jeremy Johnson587cc842024-02-08 11:45:44 +00003101 arg_list = TosaArgGen._add_data_generators(
3102 testGen,
3103 opName,
evacha019c96eef2024-02-07 11:21:55 +00003104 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003105 dtype,
3106 arg_list,
3107 error_name,
3108 )
3109 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003110 return arg_list
3111
Luke Hutton57287132023-02-06 14:54:18 +00003112 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003113 def agFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Luke Hutton57287132023-02-06 14:54:18 +00003114 arg_list = []
3115
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003116 shape = shapeList[0]
3117 dot_products = gtu.product(shape)
3118 ks = 2 * shape[1] * shape[2] # 2*H*W
3119 for inverse in (True, False):
3120 args_dict = {
3121 "dot_products": dot_products,
3122 "shape": shape,
3123 "ks": ks,
3124 "acc_type": dtype,
3125 "inverse": inverse,
3126 }
3127 arg_list.append((f"inverse{inverse}", args_dict))
Luke Hutton57287132023-02-06 14:54:18 +00003128
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003129 arg_list = TosaArgGen._add_data_generators(
3130 testGen,
3131 opName,
evacha019c96eef2024-02-07 11:21:55 +00003132 shapeList,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003133 dtype,
3134 arg_list,
3135 error_name,
3136 )
3137 # Return list of tuples: (arg_str, args_dict)
Luke Hutton57287132023-02-06 14:54:18 +00003138 return arg_list
3139
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003140 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003141 def agRFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003142 arg_list = []
3143
3144 shape = shapeList[0]
3145 dot_products = gtu.product(shape)
3146 ks = shape[1] * shape[2] # H*W
3147 args_dict = {
3148 "dot_products": dot_products,
3149 "shape": shape,
3150 "ks": ks,
3151 "acc_type": dtype,
3152 }
3153 arg_list.append(("", args_dict))
3154
3155 arg_list = TosaArgGen._add_data_generators(
3156 testGen,
3157 opName,
evacha019c96eef2024-02-07 11:21:55 +00003158 shapeList,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003159 dtype,
3160 arg_list,
3161 error_name,
3162 )
3163 # Return list of tuples: (arg_str, args_dict)
3164 return arg_list
3165
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003166 # Helper function for reshape. Gets some factors of a larger number.
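# e.g. getFactors(36) returns [1, 2, 3, 4, 6] - only divisors up to sqrt(val)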
3167 @staticmethod
3168 def getFactors(val, start=1):
3169 factors = []
3170
3171 for i in range(start, int(np.sqrt(val)) + 1):
3172 if (val % i) == 0:
3173 factors.append(i)
3174
3175 return factors
3176
3177 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003178 def agReshape(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003179 arg_list = []
3180
3181 origShape = shapeList[0]
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003182 totalElements = gtu.product(origShape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003183 factors = TosaArgGen.getFactors(totalElements)
3184
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003185 # Find new shapes up to the number of permutations asked for
3186 # This code is NOT fast. Fortunately, the numbers are fairly small.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003187 for p in range(testGen.args.num_rand_permutations):
Jeremy Johnsondd975b82024-02-28 17:29:13 +00003188 # Rank from 1 to MAX_TENSOR_RANK
3189 newRank = rng.randInt(1, (gtu.MAX_TENSOR_RANK + 1))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003190 if len(factors) < newRank:
3191 continue
3192
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003193 # escape_counter limits the generation of new shapes to a reasonable time
3194 for escape_counter in range(100):
3195
3196 # Generate the new shape of the chosen new rank
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003197 newShape = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003198 remainingElements = totalElements
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003199 shuffledFactors = rng.permutation(factors)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003200 for i in range(1, newRank):
3201 # pick rank-1 factors
3202 newShape.append(shuffledFactors[0])
3203 remainingElements = remainingElements // shuffledFactors[0]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003204 shuffledFactors = rng.permutation(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003205 TosaArgGen.getFactors(remainingElements)
3206 )
3207 newShape.append(remainingElements)
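# Illustrative build (assumed factors): totalElements 24 and newRank 3 might
# pick 2 then 3, leaving 24 // (2 * 3) = 4 and giving newShape [2, 3, 4]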
3208
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003209 # Check for duplicates
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003210 duplicate = False
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003211 for name, args_dict in arg_list:
3212 if args_dict["new_shape"] == newShape:
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003213 duplicate = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003214 break
3215
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003216 if not duplicate:
3217 outShape = "x".join([str(x) for x in newShape])
3218 arg_list.append(
3219 (
3220 "perm{}_rank{}_out{}".format(p, newRank, outShape),
3221 {"new_shape": newShape},
3222 )
3223 )
3224 # Found an output shape for this permutation
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003225 break
3226
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003227 # Now add data generator types
3228 arg_list = TosaArgGen._add_data_generators(
3229 testGen,
3230 opName,
evacha019c96eef2024-02-07 11:21:55 +00003231 shapeList,
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003232 dtype,
3233 arg_list,
3234 error_name,
3235 )
3236
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003237 return arg_list
3238
3239 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003240 def agTranspose(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003241 arg_list = []
3242
3243 ifm_shape = shapeList[0]
3244
3245 if error_name == ErrorIf.IndexOutsideBounds:
3246 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
3247 incorrect_small_index = range(-len(ifm_shape), 0)
3248 permutations = [p for p in itertools.permutations(incorrect_large_index)]
3249 permutations.extend(
3250 [p for p in itertools.permutations(incorrect_small_index)]
3251 )
3252 elif error_name == ErrorIf.IndexUsedTwice:
3253 # Create list with a duplicated index
3254 perm_range = list(range(len(ifm_shape)))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003255 index_choice = rng.choice(range(len(perm_range)))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003256 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
3257 permutations = [p for p in itertools.permutations(perm_range)]
3258
3259 else:
3260 # Get all permutations
3261 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
3262
3263 # Limit to possible permutations from shape dimension or argument setting
3264 limit = min(len(permutations), testGen.args.num_rand_permutations)
3265
3266 # Get random permutation generator that uses all permutations
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003267 random_permutations = rng.permutation(permutations)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003268
3269 # Create list of required amount of permutations
3270 arg_list = [
evacha0198477222024-01-26 12:25:32 +00003271 ("perm{}".format(p), {"perms": random_permutations[p].tolist()})
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003272 for p in range(limit)
3273 ]
evacha0198477222024-01-26 12:25:32 +00003274 # Now add data generator types
3275 arg_list = TosaArgGen._add_data_generators(
3276 testGen,
3277 opName,
evacha019c96eef2024-02-07 11:21:55 +00003278 shapeList,
evacha0198477222024-01-26 12:25:32 +00003279 dtype,
3280 arg_list,
3281 error_name,
3282 )
3283 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003284 return arg_list
3285
3286 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003287 def agSlice(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003288 arg_list = []
3289
3290 ifm_shape = shapeList[0]
3291 rank = len(ifm_shape)
3292
3293 for p in range(testGen.args.num_rand_permutations):
3294 start = []
3295 size = []
3296
3297 valid = True
3298
3299 for i in range(rank):
3300 if ifm_shape[i] > 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003301 start.append(rng.randInt(0, ifm_shape[i]))
3302 size.append(rng.randInt(0, ifm_shape[i] - start[i]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003303
3304 # Invalid slice size?
3305 if size[i] == 0:
3306 valid = False
3307 else:
3308 start.append(0)
3309 size.append(1)
3310
3311 if valid:
3312 # If ERROR_IF test required then incorrect start, size will be returned
3313 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003314 rng, error_name, ifm_shape, start, size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003315 )
evacha017f7d4252024-01-24 12:08:09 +00003316 arg_list.append(("perm{}".format(p), {"start": start, "size": size}))
3317 # Now add data generator types
3318 arg_list = TosaArgGen._add_data_generators(
3319 testGen,
3320 opName,
evacha019c96eef2024-02-07 11:21:55 +00003321 shapeList,
evacha017f7d4252024-01-24 12:08:09 +00003322 dtype,
3323 arg_list,
3324 error_name,
3325 )
3326 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003327 return arg_list
3328
3329 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003330 def agTile(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003331 arg_list = []
3332
3333 ifm_shape = shapeList[0]
3334 rank = len(ifm_shape)
3335
3336 for p in range(testGen.args.num_rand_permutations):
3337
3338 # Pick a few random, but small multiple values
3339 # because otherwise this has a tendency to generate
3340 # enormous tensors
3341 multiples = []
3342 for i in range(rank):
3343 if ifm_shape[i] > 1000:
3344 # Multiple of 1 if ifm_shape dimension is large to reduce
3345 # tensor size
3346 multiples.append(1)
3347 elif max(ifm_shape) > 1000:
3348 multiples.append(2)
3349 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003350 multiples.append(rng.randInt(1, 4))
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003351 arg_list.append(("perm{}".format(p), {"multiples": multiples}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003352
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003353 # Now add data generator types
3354 arg_list = TosaArgGen._add_data_generators(
3355 testGen,
3356 opName,
evacha019c96eef2024-02-07 11:21:55 +00003357 shapeList,
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003358 dtype,
3359 arg_list,
3360 error_name,
3361 )
3362 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003363 return arg_list
3364
3365 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003366 def agResize(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003367 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003368 ifm_shape = shapeList[0]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003369
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003370 def get_aspect_ratio_resize_params():
3371 common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003372 aspect_ratio = rng.choice(common_aspect_ratios)
3373 invert = rng.choice((False, True))
3374 letterbox = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003375
3376 scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
3377 scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
3378 scale_y_d = scale_x_d = 1
3379 offset_x = offset_y = 0
3380
3381 if letterbox:
3382 max_border = scale_y_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003383 border_y = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003384 border_x = 0
3385 else:
3386 # Pillarboxing
3387 border_y = 0
3388 max_border = scale_x_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003389 border_x = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003390
3391 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3392 offset = (offset_y, offset_x)
3393 border = (border_y, border_x)
3394
3395 return scale, offset, border
3396
3397 def get_upscale_downscale_params():
3398 valid_params = False
3399 while not valid_params:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003400 upscale = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003401
3402 # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003403 origin_sampling = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003404
3405 if upscale:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003406 shift = rng.randInt(low=1, high=4)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003407 scale_x_d = scale_y_d = 1
3408 scale_x_n = scale_y_n = (
3409 1 << shift if origin_sampling else 2 << shift
3410 )
3411 border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
3412 offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
3413 else:
3414 scale_x_n = 1
3415 scale_y_n = 1
3416
3417 # Return list of valid scale_*_d values (max value 4) given input dim shape
3418 def get_valid_denom(ifm_dim):
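# ifm_dim % x == 1 means (ifm_dim - 1) divides evenly by x, which (with
# scale_*_n of 1) helps keep the output size an integer; x = 1 is covered
# by the fallback below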
3419 return [x for x in range(1, 5) if ifm_dim % x == 1]
3420
3421 # Generate list of valid downscale values and choose one randomly
3422 valid_scale_y_ds = get_valid_denom(ifm_shape[1])
3423 valid_scale_x_ds = get_valid_denom(ifm_shape[2])
3424
3425 if not valid_scale_y_ds and not valid_scale_x_ds:
3426 # Bad parameters, skip
3427 continue
3428
3429 if not valid_scale_y_ds:
3430 scale_y_d = 1
3431 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003432 scale_y_d = rng.choice(valid_scale_y_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003433
3434 if not valid_scale_x_ds:
3435 scale_x_d = 1
3436 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003437 scale_x_d = rng.choice(valid_scale_x_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003438
3439 border_x = border_y = 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003440 offset_y = rng.randInt(0, 16 * scale_y_n)
3441 offset_x = rng.randInt(0, 16 * scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003442 valid_params = True
3443
3444 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3445 offset = (offset_y, offset_x)
3446 border = (border_y, border_x)
3447 return scale, offset, border
3448
3449 def get_rand_params():
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003450 def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
3451 scale = scale_n / scale_d
3452 if scale > max_scale:
3453 factor = scale / max_scale
3454 new_scale_d = math.ceil(scale_d * factor)
3455 assert scale_n / new_scale_d <= max_scale
3456 scale_d = new_scale_d
3457 return scale_d
3458
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003459 # Scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003460 scale_y_n = rng.randInt(low=1, high=(1 << 11))
3461 scale_x_n = rng.randInt(low=1, high=(1 << 11))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003462
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003463 scale_y_d = rng.randInt(low=1, high=(16 * scale_y_n))
3464 scale_x_d = rng.randInt(low=1, high=(16 * scale_x_n))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003465
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003466 scale_y_d = fix_scale_to_max_scale(
3467 scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3468 )
3469 scale_x_d = fix_scale_to_max_scale(
3470 scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3471 )
3472
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003473 # Offsets and border within the scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003474 offset_y = rng.randInt(low=-scale_y_n, high=(16 * scale_y_n))
3475 offset_x = rng.randInt(low=-scale_x_n, high=(16 * scale_x_n))
3476 border_y = rng.randInt(low=(-16 * scale_y_n), high=scale_y_n)
3477 border_x = rng.randInt(low=(-16 * scale_x_n), high=scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003478
3479 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3480 offset = (offset_y, offset_x)
3481 border = (border_y, border_x)
3482 return scale, offset, border
3483
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003484 def get_level_8k_params():
3485 # Create 64x scale - 64/1 to 2048/32
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003486 scale_d = rng.randInt(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003487 low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
3488 )
3489 scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
3490 # Create half to fifth scaling
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003491 scale_d_alt = rng.randInt(low=2, high=6)
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003492 scale_n_alt = 1
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003493 switch = rng.choice((False, True))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003494 if switch:
3495 scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
3496 else:
3497 scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)
3498
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003499 offset_y = rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
3500 offset_x = rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003501 offset = (offset_y, offset_x)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003502 border_y = rng.choice((-16 * scale[0], 0, scale[0] - 1))
3503 border_x = rng.choice((-16 * scale[2], 0, scale[2] - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003504 border = (border_y, border_x)
3505 return scale, offset, border
3506
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003507 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003508 # Exclude illegal {mode, type} configurations. Pick legal output types
3509 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
3510 outputDTypeList = [DType.INT8]
3511 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
3512 outputDTypeList = [DType.INT16]
3513 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
3514 outputDTypeList = [DType.INT32]
3515 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
3516 outputDTypeList = [DType.INT48]
James Ward8b390432022-08-12 20:48:56 +01003517 elif dtype == DType.FP16:
3518 outputDTypeList = [DType.FP16]
James Ward24dbc422022-10-19 12:20:31 +01003519 elif dtype == DType.BF16:
3520 outputDTypeList = [DType.BF16]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003521 elif dtype == DType.FP32:
3522 outputDTypeList = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00003523 elif dtype == DType.FP8E4M3:
3524 outputDTypeList = [DType.FP8E4M3]
3525 elif dtype == DType.FP8E5M2:
3526 outputDTypeList = [DType.FP8E5M2]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003527 elif error_name == ErrorIf.WrongInputType:
3528 # If an incorrect input type is used then we set a 'correct'
3529 # output type to avoid other errors
3530 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
3531 else:
3532 continue
3533
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003534 arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"
3535
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003536 for outputDType in outputDTypeList:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003537 perm = 0
3538 while perm < testGen.args.num_rand_permutations:
3539 # Random choice of type of params we are testing
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003540 if not testGen.args.level8k:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003541 _rnd_param_fn = rng.choice(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003542 (
3543 get_rand_params,
3544 get_upscale_downscale_params,
3545 get_aspect_ratio_resize_params,
3546 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003547 )
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003548 scale, offset, border = _rnd_param_fn()
3549 else:
3550 scale, offset, border = get_level_8k_params()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003551
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003552 # Expand params for bounds-checking
3553 (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
3554 (offset_y, offset_x) = offset
3555 (border_y, border_x) = border
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003556
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003557 # Make sure output dimensions OH and OW are integers
3558 partial_output_y = (
3559 (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
3560 )
3561 partial_output_x = (
3562 (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
3563 )
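# Illustrative example (assumed values): ifm height 16, scale 2/1, offset 0
# and border 0 give partial_output_y = (16 - 1) * 2 = 30 and, further down,
# output_y = 30 // 1 + 1 = 31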
3564 if error_name == ErrorIf.ResizeOutputShapeNonInteger:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003565 # Look for non-integer test
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003566 if (
3567 partial_output_y % scale_y_d == 0
3568 and partial_output_x % scale_x_d == 0
3569 ):
3570 # Skip this test as it doesn't produce NonInteger output
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003571 if perm > 0:
3572 perm += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003573 continue
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003574 else:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003575 # Alter the scaling factors to make the output integer
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003576 while partial_output_y % scale_y_d != 0:
3577 scale_y_d -= 1
3578 while partial_output_x % scale_x_d != 0:
3579 scale_x_d -= 1
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003580 # Make sure we are still within max scaling
3581 if (
3582 scale_y_n / scale_y_d
3583 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
3584 scale_x_n / scale_x_d
3585 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
3586 # Skip the test as it is using too large a scaling factor
3587 if perm > 0:
3588 perm += 1
3589 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003590
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003591 output_y = partial_output_y // scale_y_d + 1
3592 output_x = partial_output_x // scale_x_d + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003593
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003594 if (
3595 output_y >= testGen.args.max_resize_output_dim
3596 or output_x >= testGen.args.max_resize_output_dim
3597 ) and error_name is None:
3598 # Skip positive test if output dim will be too high
3599 # Avoid high test latency and OOM issues
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003600 if not testGen.args.level8k or perm > 0:
3601 perm += 1
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003602 continue
3603
3604 if (
3605 output_y <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003606 or output_y >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003607 or output_x <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003608 or output_x >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003609 ):
3610 # Output dimensions out of scope
3611 if error_name is not None and perm > 0:
3612 # As long as we have one ERROR_IF test, don't worry
3613 # about creating all the other permutations
3614 perm += 1
3615 continue
3616
3617 if error_name == ErrorIf.ResizeOutputShapeMismatch and (
3618 (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003619 output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003620 and output_y - scale_y_d < 1
3621 )
3622 or (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003623 output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003624 and output_x - scale_x_d < 1
3625 )
3626 ):
3627 # Can't create a negative test with these params as it
3628 # will create invalid output size
3629 if perm > 0:
3630 perm += 1
3631 continue
3632
3633 scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
3634 offset = [offset_y, offset_x]
3635 border = [border_y, border_x]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003636
3637 # Common for all data types
3638 if error_name is not None:
3639 (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003640 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003641 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003642 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003643 outputDTypeNew,
3644 ) = TosaErrorIfArgGen.eiResizeErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003645 rng,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003646 error_name,
3647 mode,
3648 dtype,
3649 shapeList,
3650 outputDType,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003651 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003652 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003653 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003654 )
3655 else:
3656 outputDTypeNew = outputDType
3657
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003658 arg_to_append = (
3659 arg_str.format(
3660 "N" if mode == ResizeMode.NEAREST else "B",
3661 testGen.typeStr(outputDTypeNew),
3662 scale[0],
3663 scale[1],
3664 scale[2],
3665 scale[3],
3666 offset[0],
3667 offset[1],
3668 border[0],
3669 border[1],
3670 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003671 {
3672 "mode": mode,
3673 "scale": scale,
3674 "offset": offset,
3675 "border": border,
3676 "output_dtype": outputDTypeNew,
3677 },
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003678 )
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003679 if arg_to_append in arg_list:
3680 # Skip already generated test params
3681 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003682
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003683 # Valid permutation
3684 perm += 1
3685 arg_list.append(arg_to_append)
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003686
3687 # Now add data generator types
3688 arg_list = TosaArgGen._add_data_generators(
3689 testGen,
3690 opName,
evacha019c96eef2024-02-07 11:21:55 +00003691 shapeList,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003692 dtype,
3693 arg_list,
3694 error_name,
3695 )
3696 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003697 return arg_list
3698
3699 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003700 def agTable(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003701 arg_list = []
3702
3703 if dtype == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003704 table = np.int32(rng.integers(low=-128, high=128, size=[256])).tolist()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003705 else: # INT16
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003706 table = np.int32(rng.integers(low=-32768, high=32768, size=[513])).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003707 # Make sure all slopes are within REQUIRE min/max 16-bit int
3708 for idx in range(len(table) - 1):
3709 slope = table[idx + 1] - table[idx]
3710 # Alter the next table entry to force the slope to be ok
3711 if slope > 32767:
3712 table[idx + 1] -= slope - 32767
3713 if slope < -32768:
3714 table[idx + 1] -= slope + 32768
3715 slope = table[idx + 1] - table[idx]
3716 assert slope <= 32767 and slope >= -32768
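# Illustrative clamp (assumed entries): table[idx] = -30000 and
# table[idx + 1] = 10000 give slope 40000, so the next entry is reduced by
# 40000 - 32767 to 2767, making the slope exactly 32767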
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003717 arg_list.append(
3718 (
3719 "",
Jeremy Johnson587cc842024-02-08 11:45:44 +00003720 {"table": table},
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003721 )
3722 )
Jeremy Johnson587cc842024-02-08 11:45:44 +00003723 # Now add data generator types
3724 arg_list = TosaArgGen._add_data_generators(
3725 testGen,
3726 opName,
evacha019c96eef2024-02-07 11:21:55 +00003727 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003728 dtype,
3729 arg_list,
3730 error_name,
3731 )
3732 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003733 return arg_list
3734
@staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003735 def agCondIf(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003736 # CondIf generates the condition values here.
3737 # Convert to tensors in the build function, along with the
3738 # then and else blocks
3739 arg_list = []
3740
3741 for c in [False, True]:
Jeremy Johnson587cc842024-02-08 11:45:44 +00003742 arg_list.append(("cond{}".format(int(c)), {"condition": c}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003743
Jeremy Johnson587cc842024-02-08 11:45:44 +00003744 # Now add data generator types
3745 arg_list = TosaArgGen._add_data_generators(
3746 testGen,
3747 opName,
evacha019c96eef2024-02-07 11:21:55 +00003748 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003749 dtype,
3750 arg_list,
3751 error_name,
3752 )
3753 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003754 return arg_list
3755
@staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003756 def agWhileLoop(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003757 # While loop: 0 iterations, 1, more than 1
3758 arg_list = []
3759
Jeremy Johnson587cc842024-02-08 11:45:44 +00003760 for iterations in [0, 1, 4]:
3761 arg_list.append(("iter{}".format(iterations), {"iterations": iterations}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003762
Jeremy Johnson587cc842024-02-08 11:45:44 +00003763 # Now add data generator types
3764 arg_list = TosaArgGen._add_data_generators(
3765 testGen,
3766 opName,
evacha019c96eef2024-02-07 11:21:55 +00003767 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003768 dtype,
3769 arg_list,
3770 error_name,
3771 )
3772 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003773 return arg_list