blob: 8d6c8d79b2782f0866e8ddbb0f94df75fc18d612 [file] [log] [blame]
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
Jeremy Johnsonaf090182024-02-13 18:25:39 +00004import logging
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01005import math
6
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
# DTypeNames, DType, Op and ResizeMode are convenience aliases for the
# flatc-generated types, which should be enums but aren't.
18
# Importing this module configures the root logger with its default handler
# (logging.basicConfig() is a no-op if the root logger already has handlers),
# then fetches the named logger used by the helpers in this file.
logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")
21
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010022
class TosaQuantGen:
    """Random generator helpers for quantization (zero-point) information.

    Select one of these with the 'qgen': key in an operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(rng, zeropoint, dtype, error_name=None):
        """Return a zero point suitable for ``dtype``.

        INT8 / UINT8: a supplied ``zeropoint`` is clamped into the type's
        range; otherwise one is drawn with ``rng.randInt``.
        Any other dtype: 0, unless ``error_name`` is one of the
        *ZeroPointNotZero errors, in which case a random non-zero value is
        returned (a drawn 0 is replaced by 1).
        NOTE(review): the upper bounds passed to randInt (128 / 256) suggest
        it excludes its upper bound - confirm against the rng helper.
        """
        if dtype == DType.INT8:
            low, high = -128, 127
        elif dtype == DType.UINT8:
            low, high = 0, 255
        else:
            nonzero_errors = (
                ErrorIf.InputZeroPointNotZero,
                ErrorIf.WeightZeroPointNotZero,
                ErrorIf.OutputZeroPointNotZero,
            )
            if error_name in nonzero_errors:
                drawn = rng.randInt(-128, 128)
                return drawn if drawn != 0 else 1
            return 0

        if zeropoint is not None:
            return min(high, max(low, zeropoint))
        return rng.randInt(low, high + 1)

    @staticmethod
    def qgUnary(rng, zeropoint, op, dtype, error_name=None):
        """Return ``[input_zp, output_zp]`` for a unary operator.

        Only the zero point named by ``error_name`` (input or output) is
        generated in "must be non-zero" mode.
        """
        input_err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        output_err = (
            error_name if error_name == ErrorIf.OutputZeroPointNotZero else None
        )
        # Draw the input zero point first, so rng is consumed in a fixed order.
        input_zp = TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, input_err)
        output_zp = TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, output_err)
        return [input_zp, output_zp]

    @staticmethod
    def qgConv(rng, zeropoint, op, dtype_or_dtypeList, error_name=None):
        """Return ``[input_zp, weight_zp]`` for a convolution-style operator.

        ``dtype_or_dtypeList`` is either a single dtype (shared by input,
        weights and accumulator) or a list ``[input, weights, accumulator]``.
        """
        if isinstance(dtype_or_dtypeList, list):
            input_dtype, weight_dtype = dtype_or_dtypeList[0], dtype_or_dtypeList[1]
        else:
            input_dtype = weight_dtype = dtype_or_dtypeList

        input_err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        weight_err = (
            error_name if error_name == ErrorIf.WeightZeroPointNotZero else None
        )
        input_zp = TosaQuantGen.getZeroPoint(rng, zeropoint, input_dtype, input_err)
        weight_zp = TosaQuantGen.getZeroPoint(
            rng, zeropoint, weight_dtype, weight_err
        )
        return [input_zp, weight_zp]

    @staticmethod
    def qgMatmul(rng, zeropoint, op, dtype, error_name=None):
        """Return the two input zero points for MATMUL.

        With ErrorIf.InputZeroPointNotZero, both are generated non-zero.
        """
        shared_err = (
            error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        )
        return [
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, shared_err)
            for _ in range(2)
        ]

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into an integer (multiplier, shift).

        Derived from computeMultiplierAndShiftTosaScale32. ``scale32`` selects
        a 31-bit multiplier; otherwise a 15-bit multiplier is produced. The
        result satisfies multiplier <= 2**bits and 2 <= shift <= 62
        (checked with assert).
        """
        scale_bits = 31 if scale32 else 15
        limit = 1 << scale_bits

        mantissa, exponent = math.frexp(scaleFp)
        # NOTE(review): the sign of a negative scale is dropped here; confirm
        # callers never rely on a negative multiplier.
        if scaleFp < 0.0:
            mantissa = -mantissa

        multiplier = round(mantissa * limit)
        assert multiplier <= limit

        # Rounding can reach exactly 2**bits; renormalise into range.
        if multiplier == limit:
            multiplier //= 2
            exponent += 1

        shift = scale_bits - exponent
        logger.debug(
            f"computeMultiplierAndShift: scalefp={scaleFp} scaleBits={scale_bits} m={mantissa} mult={multiplier} shift={shift}"
        )

        # Move shift into the permitted range.
        # NOTE(review): these adjustments do not keep multiplier * 2**-shift
        # constant (shift==0 divides the scale by 16, shift==63 multiplies it
        # by 4), and for shift==63 the doubled multiplier only passes the
        # assert below when it was exactly 2**(bits-1). Confirm intent.
        if shift == 0:
            multiplier, shift = multiplier // 4, 2
        elif shift == 1:
            multiplier, shift = multiplier // 2, 2
        elif shift == 63:
            multiplier, shift = multiplier * 2, 62

        assert multiplier <= limit
        assert 2 <= shift <= 62

        return multiplier, shift
156
157
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def _floor_pow2(value):
        """Return the largest power of two that is <= ``value``.

        Uses exact integer arithmetic; ``2 ** int(math.log(value, 2))`` is
        not guaranteed to be exact because math.log(x, 2) is computed as a
        ratio of floating-point logarithms.

        Raises:
            ValueError: if ``value`` is less than 1.
        """
        value = int(value)
        if value < 1:
            raise ValueError(f"cannot take power-of-two floor of {value}")
        return 1 << (value.bit_length() - 1)

    @staticmethod
    def tgBasic(testGen, rng, op, rank, error_name=None):
        """Return one shape per operand, all the same random shape.

        For ErrorIf.RankMismatch the shape is regenerated with a different
        rank after every operand except index 1, so operands after the first
        have a different rank from operand 0.
        """
        pl, const = op["operands"]
        shape = testGen.makeShape(rng, rank)

        # Restrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rng, rank + rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rng, rank + rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, rng, op, rank, error_name=None):
        """Return one copy of a random NHWC (rank 4) shape per operand."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rng, rank)
        shape = testGen.constrictBatchSize(shape)

        # Restrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgGather(testGen, rng, opName, rank, error_name=None):
        """Return ``[values_shape (N, K, C), indices_shape (N, W)]``."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_shape = testGen.makeShape(rng, rank)
        values_shape = testGen.constrictBatchSize(values_shape)

        N = values_shape[0]
        W = testGen.makeDimension(rng)
        indices_shape = [N, W]

        shape_list = [values_shape, indices_shape]
        return shape_list

    @staticmethod
    def tgScatter(testGen, rng, opName, rank, error_name=None):
        """Return ``[values_in (N, K, C), indices (N, W), input (N, W, C)]``."""
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rng, rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = rng.randInt(W_min, W_max) if W_min < W_max else W_min

        input_shape = [N, W, C]

        shape_list = []
        shape_list.append(values_in_shape)
        shape_list.append([N, W])  # indices
        shape_list.append(input_shape)

        return shape_list

    @staticmethod
    def _get_broadcast_shapes(testGen, rng, num_shapes, rank, error_name=None):
        """Return ``num_shapes`` shapes where one input is broadcast on one axis.

        For ERROR_IF tests the first shape is never the broadcast one, and the
        chosen shape is instead given a different rank (RankMismatch) or an
        incompatible dimension (BroadcastShapesMismatch).
        """
        if rank == 0:
            # No broadcasting possible for rank 0. Build independent lists so
            # callers that mutate one shape do not affect the others.
            return [[] for _ in range(num_shapes)]

        shape = testGen.makeShape(rng, rank)
        shape_list = []

        # Choose any one of the inputs to broadcast
        # Note for ERRORS: Simplifies OutputShaper code if we don't change first shape
        bcast_idx = rng.randInt(0 if error_name is None else 1, num_shapes)
        fuzz_idx = rng.randInt(0, rank)

        for i in range(num_shapes):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(rng, extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, rng, op, rank, error_name=None):
        """Return broadcast-compatible shapes for every operand of ``op``."""
        pl, const = op["operands"]
        num_shapes = pl + const
        return TosaTensorGen._get_broadcast_shapes(
            testGen, rng, num_shapes, rank, error_name
        )

    @staticmethod
    def tgMul(testGen, rng, op, rank, error_name=None):
        """Return two broadcast-compatible shapes plus ``[1]`` for the shift input."""
        # Get broadcast shapes for the first 2 inputs as the 3rd is shift
        shape_list = TosaTensorGen._get_broadcast_shapes(
            testGen, rng, 2, rank, error_name
        )
        # Add a single dimension tensor for shift
        shape_list.append([1])
        return shape_list

    @staticmethod
    def tgConv2D(testGen, rng, op, rank, error_name=None):
        """Return ``[ifm (NHWC), filter (OHWI), bias]`` shapes for CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Restrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension(rng)

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC, or randomly 1 when the operator allows a
        # broadcastable bias (rng is only consulted in that case)
        if op.get("broadcastable_bias", False) and rng.choice([True, False]):
            ofm_depth = 1
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, rng, op, rank, error_name=None):
        """Return ``[ifm (NDHWC), filter (ODHWI), bias (OC)]`` shapes for CONV3D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Restrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension(rng)

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, rng, op, rank, error_name=None):
        """Return ``[ifm (NHWC), filter (OHWI), bias (OC)]`` for TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Restrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension(rng)

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, rng, op, rank, error_name=None):
        """Return ``[ifm (NHWC), filter (HWCM), bias (C*M)]`` shapes."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
            assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Restrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C. Guard the modulus so a small upper shape
        # range (< 4) cannot cause a division by zero.
        m_modulus = max(1, testGen.args.tensor_shape_range[1] // 4)
        filter_m = (testGen.makeDimension(rng) % m_modulus) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def _fft_ifm_shape(testGen, rng, rank, error_name):
        """Build the NHW input shape shared by FFT2D and RFFT2D.

        H and W are rounded down to powers of two, then (for
        ErrorIf.KernelNotPowerOfTwo) bumped so at least one of them is not a
        power of two. The batch size is constricted last.
        """
        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rng, rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = TosaTensorGen._floor_pow2(ifm_shape[1])
        ifm_shape[2] = TosaTensorGen._floor_pow2(ifm_shape[2])

        # Restrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        return testGen.constrictBatchSize(ifm_shape)

    @staticmethod
    def tgFFT2d(testGen, rng, op, rank, error_name=None):
        """Return the two (real, imaginary) NHW input shapes for FFT2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
            assert pl == 2 and const == 0

        ifm_shape = TosaTensorGen._fft_ifm_shape(testGen, rng, rank, error_name)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, rng, op, rank, error_name=None):
        """Return the single NHW input shape for RFFT2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
            assert pl == 1 and const == 0

        ifm_shape = TosaTensorGen._fft_ifm_shape(testGen, rng, rank, error_name)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, rng, op, rank, error_name=None):
        """Return ``[input (N, IC), filter (OC, IC), bias (OC)]`` shapes."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rng, rank)

        # Restrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, rng, op, rank, error_name=None):
        """Return ``[a (N, H, C), b (N, C, W)]`` shapes for MATMUL."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
            assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rng, rank)

        # Restrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, rng, op, rank, error_name=None):
        """Return shapes for CONCAT: the operator's operands plus 0-3 extras.

        For ErrorIf.ConcatInputRankMismatch every shape after the first has
        its rank changed (leading dim dropped, or a dim appended).
        """
        pl, const = op["operands"]
        shape = testGen.makeShape(rng, rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = rng.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(rng, shapeList, axis, error_name=None):
        """Split the first concat shape along ``axis`` into several inputs.

        Returns ``shapeList`` unchanged (or with non-axis dims invalidated for
        ErrorIf.ConcatInputDimMismatch) when the axis cannot be split.
        """
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        remaining_length = shape[axis]
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
653
654
655class TosaTensorValuesGen:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100656 """Tensor Value generators create the random data for each tensor in each test."""
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100657
658 def __init__(self):
659 pass
660
Jeremy Johnson1271c442023-09-05 11:39:26 +0100661 class TVGInfo:
662 """Enhanced tensor values information including data gen dict."""
663
664 def __init__(self, tensorList, dataGenDict):
665 self.tensorList = tensorList
666 self.dataGenDict = dataGenDict
667
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100668 # Default high value for random numbers
669 TVG_FLOAT_HIGH_VALUE = {
670 DType.FP32: (1 << 128) - (1 << (127 - 23)),
671 DType.FP16: (1 << 16) - (1 << (15 - 10)),
672 DType.BF16: (1 << 128) - (1 << (127 - 7)),
Won Jeon2c34b462024-02-06 18:37:00 +0000673 DType.FP8E4M3: 448,
674 DType.FP8E5M2: 57344,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100675 }
676
Jeremy Johnson30476252023-11-20 16:15:30 +0000677 # Default lowest normal values for random numbers
678 TVG_FLOAT_LOW_VALUE = {
679 DType.FP32: np.exp2(-126),
680 DType.FP16: np.exp2(-14),
681 DType.BF16: np.exp2(-126),
Won Jeon2c34b462024-02-06 18:37:00 +0000682 DType.FP8E4M3: np.exp2(-9),
683 DType.FP8E5M2: np.exp2(-16),
Jeremy Johnson30476252023-11-20 16:15:30 +0000684 }
685
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100686 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100687 def _get_data_range(rng, dtype, highValueLookup, lowValueLookup=None):
Jeremy Johnson30476252023-11-20 16:15:30 +0000688 # Return a tuple of (low,high) data range values for the given data
689 # type using a combination of per operator table limits, data limits
690 # and user supplied ranges for FP numbers
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000691 if dtype in highValueLookup:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100692 type_range = rng.dTypeRange(dtype, high_inclusive=True)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000693 high_val = highValueLookup[dtype]
Jeremy Johnson30476252023-11-20 16:15:30 +0000694 if lowValueLookup is not None and dtype in lowValueLookup:
695 low_val = lowValueLookup[dtype]
696 else:
697 low_val = -high_val
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000698 # Set the values to something that won't produce infinity whilst
Jeremy Johnson30476252023-11-20 16:15:30 +0000699 # respecting the default ranges if more/less than the low/high
700 # values
701 data_range = (
702 max(low_val, type_range[0]),
703 min(high_val, type_range[1]),
704 )
705 if data_range[0] > data_range[1]:
706 # Invalid data range from low to high created due to user
707 # constraints revert to using internal ranges as they are
708 # known to work
Jeremy Johnsonaf090182024-02-13 18:25:39 +0000709 logger.info(
710 f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
711 )
Jeremy Johnson30476252023-11-20 16:15:30 +0000712 data_range = (low_val, high_val)
713 return data_range
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000714 return None
715
@staticmethod
def tvgLazyGenDefault(
    testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
    """Create the operator input tensors using the default data generation.

    ERROR_IF tests, data types not supported by compliance and operators
    without a "data_gen" entry fall back to generating the data here with
    ``rng``. All other tests get data generator meta-data describing each
    tensor. The data itself is only created here when lazy data generation
    is off, or for the DOT_PRODUCT bias tensor (input 2), whose contents
    affect the compliance KS value.

    Args:
        testGen: test generator providing the op list, serializer, command
            line args and data generator library.
        rng: random generator used for tensor data and seeds.
        opName: key into ``testGen.TOSA_OP_LIST``.
        dtypeList: data type of each input tensor.
        shapeList: shape of each input tensor.
        argsDict: operator arguments. May be updated with ``"ksb"``.
        error_name: ERROR_IF test name, or None for a valid test.

    Returns:
        TosaTensorValuesGen.TVGInfo holding the serialized tensors and the
        data generator meta-data (None when the data was generated here).
    """
    # Variable inputs versus constants
    pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
    if "p_count" in argsDict:
        # Override for operators like CONCAT
        pCount = argsDict["p_count"]
        cCount = argsDict["c_count"]
    assert pCount + cCount == len(
        shapeList
    ), "Placeholders & Constant tensors must match shapes list"

    tens_ser_list = []

    if (
        error_name is not None
        or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
        or "data_gen" not in testGen.TOSA_OP_LIST[opName]
    ):
        # Fall back to internal data gen when dealing with unsupported types or ops
        shared_range = argsDict.get("data_range")
        for idx, (shape, dtype) in enumerate(zip(shapeList, dtypeList)):
            roundMode = False
            # Start from the shared range for every tensor so the inclusive to
            # exclusive conversion below is applied exactly once per tensor
            # rather than accumulating across tensors
            data_range = shared_range
            if "data_range_list" in argsDict:
                range_info = argsDict["data_range_list"][idx]
                data_range = range_info["range"]
                roundMode = "round" in range_info and range_info["round"] is True
            if data_range is not None and dtype not in (
                DType.FP16,
                DType.FP32,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            ):
                # Change from inclusive to exclusive range
                data_range = (data_range[0], data_range[1] + 1)

            # Ignore lazy data gen option and create data array using any range limits
            if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                fixed = argsDict["fixed_data"][idx]
                if dtype == DType.SHAPE:
                    arr = np.int64(fixed)
                elif dtype == DType.INT8:
                    arr = np.int8(fixed)
                elif dtype == DType.INT16:
                    arr = np.int16(fixed)
                elif dtype == DType.INT32:
                    arr = np.int32(fixed)
                else:
                    assert False, "Unsupported fixed_data type"
            else:
                arr = rng.randTensor(shape, dtype, data_range)
            if roundMode:
                arr = np.round(arr)
            if idx < pCount:
                tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
            else:
                tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

    # Create data generator meta-data
    dg_type = argsDict["dg_type"]
    tens_data = {
        "version": "0.1",
        "tensors": {},
    }
    dg_tens_meta = tens_data["tensors"]
    for idx, shape in enumerate(shapeList):

        tens_meta = {}
        if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
            tens_meta["generator"] = gtu.DataGenType(
                gtu.DataGenType.FIXED_DATA
            ).name
        else:
            tens_meta["generator"] = gtu.DataGenType(dg_type).name

        tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
        tens_meta["shape"] = [int(i) for i in shape]
        tens_meta["input_pos"] = idx
        tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()

        if testGen.args.random_const_inputs:
            # Choose type of tensor biased by defaults
            percentage = rng.randInt(0, 100)
            variable = (idx < pCount and percentage < 70) or (
                idx >= pCount and percentage >= 70
            )
        else:
            # Use default set up of constants versus inputs for the op
            variable = idx < pCount

        if variable:
            tens_meta["input_type"] = "VARIABLE"
        else:
            tens_meta["input_type"] = "CONSTANT"

        if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
            info = {}
            if (
                tens_meta["generator"]
                == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name
            ):
                info["data"] = [int(i) for i in argsDict["fixed_data"][idx]]
                tens_meta["fixed_data_info"] = info
            else:
                info["rng_seed"] = rng.seed

                data_range = None
                if "data_range_list" in argsDict:
                    data_range = argsDict["data_range_list"][idx]["range"]
                    if "round" in argsDict["data_range_list"][idx]:
                        info["round"] = argsDict["data_range_list"][idx]["round"]
                elif "data_range" in argsDict:
                    data_range = argsDict["data_range"]

                if data_range is None:
                    data_range = rng.dTypeRange(dtypeList[idx], high_inclusive=True)
                info["range"] = [str(v) for v in data_range]
                tens_meta["pseudo_random_info"] = info
        elif dg_type == gtu.DataGenType.DOT_PRODUCT:
            info = {}
            info["s"] = argsDict["s"]
            info["ks"] = int(argsDict["ks"])
            if "acc_type" in argsDict:
                # Convert type number into JSON name
                info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
                    "json"
                ]
            if "kernel" in argsDict:
                info["kernel"] = [int(k) for k in argsDict["kernel"]]
            if "axis" in argsDict:
                info["axis"] = int(argsDict["axis"])
            tens_meta["dot_product_info"] = info
        elif dg_type == gtu.DataGenType.FULL_RANGE:
            info = {}
            info["start_val"] = int(
                rng.randInt(0, gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["fullset"])
            )
            tens_meta["full_range_info"] = info
        else:
            # TODO - other data gen type
            assert False, "TODO: support other data gen types"

        # Using the finished generate config meta data - generate the data if
        # needed and assign a tensor name from the serializer

        # Need to generate data when not lazy or for the bias tensor as we need
        # to work out if the bias data is non-zero for compliance
        if not testGen.args.lazy_data_gen or (
            idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
        ):
            # Give this tensor a temporary name until we get one from the serializer
            temp_name = f"placeholder_{idx}"
            dg_tens_meta[temp_name] = tens_meta
            # Create data now using the temporary name to access meta details
            data = testGen.dgl.get_tensor_data(temp_name, tens_data)
            if tens_meta["data_type"] == "SHAPE":
                # Tensor type SHAPE and Numpy file type must be the same
                data = np.int64(data)
            # Remove the item as we will give it the correct name later
            del dg_tens_meta[temp_name]

        if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
            # The KS value used by compliance verification is altered when the
            # bias data is non-zero. Checked element-wise so any tensor rank
            # (and an empty tensor) is handled, unlike builtin max().
            if np.any(np.asarray(data) != 0):
                argsDict["ksb"] = argsDict["ks"] + 1

        if testGen.args.lazy_data_gen:
            data = None

        if variable:
            tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
        else:
            tens = testGen.ser.addConst(shape, dtypeList[idx], data)

        tens_ser_list.append(tens)
        # Add the meta data to the list using the serializer tensor name
        dg_tens_meta[tens.name] = tens_meta

    return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
903
@staticmethod
def tvgNegate(
    testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
    """Create the input tensor for NEGATE.

    Valid INT32 tests draw values from [-(2^31 - 1), 2^31 - 1], so every
    value can be negated within the int32 accumulator range. ERROR_IF tests
    and all other data types use tvgLazyGenDefault.
    """
    if error_name is not None or dtypeList[0] != DType.INT32:
        # ERROR_IF or floating point test
        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    # Integer test
    num_placeholders, num_consts = testGen.TOSA_OP_LIST[opName]["operands"]
    assert (
        num_placeholders == 1 and num_consts == 0
    ), "Op.NEGATE must have 1 placeholders, 0 consts"

    # Symmetric range: -(2^31) is excluded because its negation overflows
    limit = (1 << 31) - 1
    values = np.int32(rng.integers(low=-limit, high=(limit + 1), size=shapeList[0]))
    placeholder = testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values)
    return TosaTensorValuesGen.TVGInfo([placeholder], None)
932
# Limit ADD/SUB float data to half of the largest value (TVG_FLOAT_HIGH_VALUE)
# so that combining two in-range values cannot overflow to infinity.
# Built from explicit pairs because a comprehension inside a class body
# cannot see class-level names such as TVG_FLOAT_HIGH_VALUE.
TVG_FLOAT_HIGH_VALUE_ADDSUB = dict(
    (
        (DType.FP32, TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
        (DType.FP16, TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
        (DType.BF16, TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
    )
)
939
@staticmethod
def tvgAddSub(
    testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
    """Create the two input tensors for ADD / SUB (and ADD_SHAPE / SUB_SHAPE).

    ERROR_IF tests and non INT32/SHAPE types use tvgLazyGenDefault, with the
    float data range limited by TVG_FLOAT_HIGH_VALUE_ADDSUB when available.

    Valid INT32 / SHAPE tests compute the exact result in int64 and adjust
    the second operand so the int32 result does not saturate, i.e. wrap due
    to the limited number of bits. Where the second operand is broadcast,
    the adjusted values are folded back to its shape with np.amin / np.amax
    along each broadcast axis.
    """
    if error_name is not None or dtypeList[0] not in (DType.INT32, DType.SHAPE):
        # ERROR_IF or floating point test
        data_range = TosaTensorValuesGen._get_data_range(
            rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
        )
        if data_range:
            argsDict["data_range"] = data_range

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    op = testGen.TOSA_OP_LIST[opName]
    num_placeholders, num_consts = op["operands"]
    assert (
        num_placeholders == 2 and num_consts == 0
    ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"

    is_add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
    # Shape operators use the user shape range, everything else the default
    value_range = (
        testGen.args.tensor_shape_range
        if op["op"] in (Op.ADD_SHAPE, Op.SUB_SHAPE)
        else None
    )
    lhs_arr = rng.randTensor(shapeList[0], dtypeList[0], value_range)
    rhs_arr = rng.randTensor(shapeList[1], dtypeList[1], value_range)

    # Exact result in int64 so values outside int32 are visible
    combine = np.add if is_add else np.subtract
    wide_res = combine(lhs_arr, rhs_arr, dtype=np.int64)

    # Amount each result lies above int32 max (>= 0) / below int32 min (<= 0)
    over_arr = np.maximum(wide_res - np.full(shapeList[1], (1 << 31) - 1), 0)
    under_arr = np.minimum(wide_res - np.full(shapeList[1], -(1 << 31)), 0)
    if not is_add:
        # For subtraction the second operand must move the opposite way, so
        # swap and negate the corrections
        over_arr, under_arr = -under_arr, -over_arr

    def fold_to_rhs_shape(values, reduce_fn):
        # Collapse axes that the second operand broadcast from size 1
        for axis, dim in enumerate(rhs_arr.shape):
            if dim != values.shape[axis]:
                assert (
                    dim == 1
                ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                values = reduce_fn(values, axis=axis, keepdims=True)
        return values

    rhs_unsat = rhs_arr
    if (over_arr != 0).any():
        # Clip values that cause saturation above the maximum
        rhs_unsat = fold_to_rhs_shape(
            np.subtract(rhs_unsat, over_arr, dtype=np.int32), np.amin
        )
    if (under_arr != 0).any():
        # Clip values that cause saturation below the minimum
        rhs_unsat = fold_to_rhs_shape(
            np.subtract(rhs_unsat, under_arr, dtype=np.int32), np.amax
        )

    tens_ser_list = [
        testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], lhs_arr),
        testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], rhs_unsat),
    ]
    return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1021
@staticmethod
def tvgCondIfWhileLoop(
    testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
    """Create the input tensors for the cond_if_binary / while_loop tests.

    INT32 values are drawn with the INT16 random range, so the add/sub ops
    inside the graphs do not saturate. INT16 / INT8 values lie in [0, 32),
    which keeps logical shift amounts in 0 to 31. The first ``operands[0]``
    tensors become placeholders and the remainder constants. Other data
    types use tvgLazyGenDefault.
    """
    if dtypeList[0] not in (DType.INT32, DType.INT16, DType.INT8):
        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    num_placeholders, _ = testGen.TOSA_OP_LIST[opName]["operands"]
    tensors = []
    for idx, shape in enumerate(shapeList):
        if dtypeList[0] == DType.INT32:
            values = rng.randTensor(shape, DType.INT16)
        else:
            values = np.int32(rng.integers(low=0, high=32, size=shape))
        # Placeholders come first, the remaining inputs are constants
        add_tensor = (
            testGen.ser.addPlaceholder
            if idx < num_placeholders
            else testGen.ser.addConst
        )
        tensors.append(add_tensor(shape, dtypeList[idx], values))

    return TosaTensorValuesGen.TVGInfo(tensors, None)
1058
@staticmethod
def tvgArithmeticRightShift(
    testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
    """Create the value and shift tensors for ARITHMETIC_RIGHT_SHIFT.

    The shift operand (input 1) is drawn from [0, bits), where bits is the
    width of its data type. The WrongInputType ERROR_IF test uses 8 bits;
    any other unexpected type raises. The other operand is drawn with
    ``rng.randTensor``.
    """
    num_placeholders, num_consts = testGen.TOSA_OP_LIST[opName]["operands"]
    assert (
        num_placeholders == 2 and num_consts == 0
    ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

    shift_bits = {DType.INT8: 8, DType.INT16: 16, DType.INT32: 32}
    tensors = []
    for idx, shape in enumerate(shapeList):
        dtype = dtypeList[idx]
        if idx == 1:
            bits = shift_bits.get(dtype)
            if bits is None:
                if error_name != ErrorIf.WrongInputType:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
                bits = 8
            values = np.int32(rng.integers(low=0, high=bits, size=shape))
        else:
            values = rng.randTensor(shape, dtype)
        tensors.append(testGen.ser.addPlaceholder(shape, dtype, values))

    return TosaTensorValuesGen.TVGInfo(tensors, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001088
@staticmethod
def tvgReshape(
    testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
    """Supply RESHAPE's new shape as a fixed-data SHAPE tensor (input 1).

    ``dtypeList`` and ``shapeList`` are updated in place and
    ``argsDict["fixed_data"]`` is set before delegating to
    tvgLazyGenDefault.
    """
    dtypeList[1] = DType.SHAPE
    target_shape = argsDict["new_shape"]
    shapeList[1] = [len(target_shape)]
    # Only input 1 is pre-generated; input 0 is generated as usual
    argsDict["fixed_data"] = [None, target_shape]

    return TosaTensorValuesGen.tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
    )
1101
@staticmethod
def tvgRescale(
    testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
    """Supply RESCALE's multiplier and shift values as fixed-data tensors.

    Input 1 holds ``argsDict["multiplier"]``. It is INT32 when
    ``argsDict["scale"]`` is truthy, otherwise INT16. Input 2 holds
    ``argsDict["shift"]`` as INT8. ``dtypeList`` and ``shapeList`` are
    updated in place.
    """
    use_scale32 = argsDict["scale"]
    multipliers = argsDict["multiplier"]
    shifts = argsDict["shift"]

    dtypeList[1] = DType.INT32 if use_scale32 else DType.INT16
    shapeList[1] = [len(multipliers)]
    dtypeList[2] = DType.INT8
    shapeList[2] = [len(shifts)]
    # Only inputs 1 and 2 are pre-generated
    argsDict["fixed_data"] = [None, multipliers, shifts]

    return TosaTensorValuesGen.tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
    )
1123
@staticmethod
def tvgPad(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
    """Supply PAD's padding as a fixed-data SHAPE tensor (input 1).

    ``argsDict["pad"]`` is a 2D array that is flattened into a 1D list of
    values. ``dtypeList`` and ``shapeList`` are updated in place.
    """
    flat_padding = argsDict["pad"].flatten()
    dtypeList[1] = DType.SHAPE
    shapeList[1] = [len(flat_padding)]
    # Only input 1 is pre-generated
    argsDict["fixed_data"] = [None, flat_padding]

    return TosaTensorValuesGen.tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
    )
1136
@staticmethod
def tvgSlice(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
    """Supply SLICE's start and size as fixed-data SHAPE tensors (inputs 1, 2).

    ``dtypeList`` and ``shapeList`` are updated in place.
    """
    dtypeList[1] = DType.SHAPE
    slice_start = argsDict["start"]
    shapeList[1] = [len(slice_start)]
    dtypeList[2] = DType.SHAPE
    slice_size = argsDict["size"]
    shapeList[2] = [len(slice_size)]
    # Only inputs 1 and 2 are pre-generated
    argsDict["fixed_data"] = [None, slice_start, slice_size]

    return TosaTensorValuesGen.tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
    )
1149
@staticmethod
def tvgTile(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
    """Supply TILE's multiples as a fixed-data SHAPE tensor (input 1).

    ``dtypeList`` and ``shapeList`` are updated in place.
    """
    dtypeList[1] = DType.SHAPE
    multiples = argsDict["multiples"]
    shapeList[1] = [len(multiples)]
    # Only input 1 is pre-generated
    argsDict["fixed_data"] = [None, multiples]

    return TosaTensorValuesGen.tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
    )
1159
@staticmethod
def tvgSelect(
    testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
    """Make SELECT's condition tensor (input 0) BOOL, then use the default generator."""
    dtypeList[0] = DType.BOOL
    lazy_gen = TosaTensorValuesGen.tvgLazyGenDefault
    return lazy_gen(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name)
1170
@staticmethod
def tvgIntDiv(
    testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
    """Create the dividend and divisor tensors for INTDIV.

    Valid tests redraw both tensors until neither invalid INTDIV case is
    present:
      1. a zero divisor;
      2. a dividend of -(1 << 31) together with a divisor of -1, checked
         across the whole tensors rather than element-wise.
    ERROR_IF tests use tvgLazyGenDefault.
    """
    if error_name is not None:
        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    num_placeholders, num_consts = testGen.TOSA_OP_LIST[opName]["operands"]
    assert (
        num_placeholders == 2 and num_consts == 0
    ), "Op.INTDIV must have 2 placeholders, 0 consts"

    int32_min = -(2**31)
    while True:
        dividend_arr = rng.randTensor(shapeList[0], dtypeList[0])
        divisor_arr = rng.randTensor(shapeList[1], dtypeList[1])
        divide_by_zero = (divisor_arr == 0).any()
        overflow = (dividend_arr == int32_min).any() and (divisor_arr == -1).any()
        if not (divide_by_zero or overflow):
            break

    tensors = [
        testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr),
        testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr),
    ]
    return TosaTensorValuesGen.TVGInfo(tensors, None)
1211
# Limit MUL float data to the square root of the largest value
# (TVG_FLOAT_HIGH_VALUE) so that the product of two in-range values cannot
# overflow to infinity. Built from explicit pairs because a comprehension
# inside a class body cannot see class-level names such as
# TVG_FLOAT_HIGH_VALUE.
TVG_FLOAT_HIGH_VALUE_MUL = dict(
    (
        (DType.FP32, math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32])),
        (DType.FP16, math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16])),
        (DType.BF16, math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16])),
    )
)
1219
@staticmethod
def tvgMul(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
    """Create the MUL / MUL_SHAPE input tensors.

    ERROR_IF and FP16/BF16/FP32 tests use tvgLazyGenDefault, with the float
    data range limited by TVG_FLOAT_HIGH_VALUE_MUL. For MUL (not MUL_SHAPE)
    input 2 becomes a fixed INT8 shift tensor of shape [1] holding
    ``argsDict["shift"]``.

    Integer tests draw both operands and halve them until every
    ``a * b`` lies within (-2^31, 2^31 - 1]. When shift > 0 the product is
    first rounded and shifted right by ``shift``. MUL_SHAPE serializes the two
    operands as int64; MUL serializes them as int32 plus the INT8 shift
    placeholder.
    """
    if error_name is not None or dtypeList[0] in (
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ):
        # ERROR_IF or floating point test
        data_range = TosaTensorValuesGen._get_data_range(
            rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
        )
        if data_range:
            argsDict["data_range"] = data_range

        if dtypeList[0] != DType.SHAPE:
            # Need to supply shift tensor for MUL (not needed for MUL_SHAPE)
            dtypeList[2] = DType.INT8
            shapeList[2] = [1]
            # Create a new list for the pre-generated data in argsDict["fixed_data"]
            argsDict["fixed_data"] = [None, None, [argsDict["shift"]]]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    # Integer test - make sure the multiply result is in int32 range
    shift = 0 if dtypeList[0] == DType.SHAPE else argsDict["shift"]

    bits_by_dtype = {
        DType.INT8: 8,
        DType.INT16: 16,
        DType.INT32: 32,
        DType.SHAPE: 32,
    }
    num_bits = bits_by_dtype.get(dtypeList[0])
    if num_bits is None:
        if error_name != ErrorIf.WrongInputType:
            raise Exception(
                f"OpMul: invalid input dtype {gtu.DTYPE_ATTRIBUTES[dtypeList[0]]['str']}"
            )
        num_bits = 8

    def fits_int32(lhs, rhs):
        # Reproduce the (optionally rounded and shifted) product in int64
        product = lhs.astype(np.int64) * rhs.astype(np.int64)
        if shift > 0:
            product = (product + (1 << (shift - 1))) >> shift
        return (product > -(2**31)).all() and (product <= ((2**31) - 1)).all()

    # NOTE(review): both operands are redrawn once per entry in shapeList and
    # only the final draw is kept; kept as-is so the random stream (and so
    # the generated test data) is unchanged
    for idx, _ in enumerate(shapeList):
        if dtypeList[idx] == DType.SHAPE:
            low = testGen.args.tensor_shape_range[0]
            high = testGen.args.tensor_shape_range[1]
        else:
            half_range = 2 ** (num_bits - 1)
            low, high = -half_range, half_range - 1

        a_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[0]))
        b_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[1]))
        # Halve both operands until the result fits
        while not fits_int32(a_arr, b_arr):
            a_arr = a_arr // 2
            b_arr = b_arr // 2

    ser = testGen.ser
    if dtypeList[0] == DType.SHAPE:
        # MUL_SHAPE with 2 inputs
        tensors = [
            ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr.astype(np.int64)),
            ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr.astype(np.int64)),
        ]
    else:
        # MUL with 3 inputs (3rd is shift)
        tensors = [
            ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr.astype(np.int32)),
            ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr.astype(np.int32)),
            ser.addPlaceholder([1], DType.INT8, np.int8([shift])),
        ]

    return TosaTensorValuesGen.TVGInfo(tensors, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001325
@staticmethod
def tvgConcat(
    testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
    """Create the CONCAT / CONCAT_SHAPE input tensors.

    ``testGen.args.num_const_inputs_concat`` of the inputs become constants.
    A value of 0 keeps every input a placeholder; otherwise at least one
    placeholder remains. The shapes are rebuilt by
    TosaTensorGen.tgConcatConstInput along the concat axis (axis 0 for
    CONCAT_SHAPE). The placeholder/constant split is then passed to
    tvgLazyGenDefault through ``argsDict``.
    """
    num_const = testGen.args.num_const_inputs_concat
    if num_const == 0:
        num_placeholders = len(shapeList)
    else:
        num_placeholders = max(1, len(shapeList) - num_const)

    # CONCAT_SHAPE always concatenates along axis 0
    concat_axis = (
        0
        if testGen.TOSA_OP_LIST[opName]["op"] == Op.CONCAT_SHAPE
        else argsDict["axis"]
    )
    shapeList = TosaTensorGen.tgConcatConstInput(
        rng, shapeList, concat_axis, error_name
    )

    # Override default pCount/cCount for operator
    argsDict["p_count"] = num_placeholders
    argsDict["c_count"] = len(shapeList) - num_placeholders

    return TosaTensorValuesGen.tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
    )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001352
@staticmethod
def tvgLogicalShift(
    testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
    """Build the two placeholder inputs for LOGICAL_LEFT/RIGHT_SHIFT.

    Input 0 is a random tensor of ``dtypeList[0]``; input 1 holds shift
    amounts from ``rng.integers(low=0, high=32)`` converted to int32.
    """
    p_count, c_count = testGen.TOSA_OP_LIST[opName]["operands"]
    assert (
        p_count == 2 and c_count == 0
    ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"

    values_shape, shift_shape = shapeList[0], shapeList[1]
    values = rng.randTensor(values_shape, dtypeList[0])
    shifts = np.int32(rng.integers(low=0, high=32, size=shift_shape))

    # List literal keeps the serialization order: values first, then shifts
    placeholders = [
        testGen.ser.addPlaceholder(values_shape, dtypeList[0], values),
        testGen.ser.addPlaceholder(shift_shape, dtypeList[1], shifts),
    ]
    return TosaTensorValuesGen.TVGInfo(placeholders, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001373
@staticmethod
def tvgEqual(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
    """Generate EQUAL inputs.

    ERROR_IF tests and data types supported by the compliance data
    generator go through ``tvgLazyGenDefault``. Other (integer) types are
    built here with some element pairs deliberately made equal.
    """
    if error_name is not None or gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
        # ERROR_IF or floating point test
        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    # Integer test
    p_count, c_count = testGen.TOSA_OP_LIST[opName]["operands"]
    assert (
        p_count == 2 and c_count == 0
    ), "Op.EQUAL must have 2 placeholders, 0 consts"

    a_shape, b_shape = shapeList[0], shapeList[1]
    a_arr = rng.randTensor(a_shape, dtypeList[0])
    b_arr = rng.randTensor(b_shape, dtypeList[1])

    # Independent random data would almost never contain equal values, so
    # copy values from b into a at (2 * rank) randomly chosen positions
    rank = len(a_shape)
    for _ in range(rank * 2):
        a_pos = []
        b_pos = []
        for axis in range(rank):
            # Draw up to the larger of the two dimensions on this axis...
            pick = np.int32(
                rng.integers(0, max(a_shape[axis], b_shape[axis]))
            )
            # ...then clamp per tensor so broadcast dimensions stay in range
            a_pos.append(min(a_shape[axis] - 1, pick))
            b_pos.append(min(b_shape[axis] - 1, pick))
        a_arr[tuple(a_pos)] = b_arr[tuple(b_pos)]

    tens_ser_list = [
        testGen.ser.addPlaceholder(a_shape, dtypeList[0], a_arr),
        testGen.ser.addPlaceholder(b_shape, dtypeList[1], b_arr),
    ]
    return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1418
@staticmethod
def tvgReduceSum(
    testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
    """Generate the REDUCE_SUM input tensor.

    INT32 inputs are created here with values bounded so that summing along
    any axis cannot overflow int32. Other types are produced by the lazy
    data generator. For non ERROR_IF tests that are not using the dot
    product generator, the float data range is divided by the largest
    dimension so axis sums stay finite.
    """
    dtype = dtypeList[0]
    if dtype == DType.INT32:
        op = testGen.TOSA_OP_LIST[opName]
        pCount, cCount = op["operands"]
        assert (
            pCount == 1 and cCount == 0
        ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
        # Limit values so that the sum cannot exceed the range of an int32
        # during summation of any axis: values are drawn from
        # [-range_val, range_val) with range_val = 2**31 // largest_dim.
        # Integer floor division keeps the bound exact (no float rounding).
        range_val = (1 << 31) // max(shapeList[0])
        values_arr = np.int32(
            rng.integers(low=-range_val, high=range_val, size=shapeList[0])
        )
        tens_ser_list = [
            testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
        ]
        return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

    # ERROR_IF or floating point test.
    # "dg_type" holds a gtu.DataGenType value (see
    # TosaArgGen._add_data_generators), so compare against DataGenType
    # rather than the unrelated ComplianceMode enum.
    if (
        error_name is None
        and argsDict["dg_type"] != gtu.DataGenType.DOT_PRODUCT
    ):
        # Limit ranges for (non error & non dot product) tests by using
        # values that can be summed on any axis to not hit infinity
        highval_lookup = {
            dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
            / max(shapeList[0])
        }
        data_range = TosaTensorValuesGen._get_data_range(
            rng, dtype, highval_lookup
        )
        assert data_range is not None
        argsDict["data_range"] = data_range

    return TosaTensorValuesGen.tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
    )
1462
@staticmethod
def tvgReduceProduct(
    testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
    """Generate REDUCE_PRODUCT inputs using the lazy data generator.

    For non ERROR_IF tests, the data range high value is set to the N-th
    root of the type's high value (N = largest input dimension), so a
    product along any axis does not reach infinity.
    """
    if error_name is None:
        dtype = dtypeList[0]
        longest_axis = max(shapeList[0])
        safe_high = math.pow(
            TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
            1 / longest_axis,
        )
        data_range = TosaTensorValuesGen._get_data_range(
            rng, dtype, {dtype: safe_high}
        )
        assert data_range is not None
        argsDict["data_range"] = data_range

    return TosaTensorValuesGen.tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
    )
1484
@staticmethod
def tvgResize(
    testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
    """Generate RESIZE inputs.

    Input 0 uses the float high-value data range when one is available
    (also recorded as "max_abs_value" for compliance). Inputs 1-3 become
    SHAPE tensors holding the fixed scale, offset and border values taken
    from ``argsDict``. ``dtypeList`` and ``shapeList`` are updated in place.
    """
    data_range = TosaTensorValuesGen._get_data_range(
        rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE
    )
    if data_range:
        argsDict["data_range"] = data_range
        # Needed for compliance
        argsDict["max_abs_value"] = data_range[1]

    # scale, offset, border -> SHAPE inputs at positions 1, 2, 3
    fixed_values = [argsDict["scale"], argsDict["offset"], argsDict["border"]]
    for position, values in enumerate(fixed_values, start=1):
        dtypeList[position] = DType.SHAPE
        shapeList[position] = [len(values)]
    argsDict["fixed_data"] = [None] + fixed_values

    return TosaTensorValuesGen.tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
    )
1513
# High value for the POW exponent data range, per float type
# (passed as the high-value lookup to _get_data_range in tvgPow)
TVG_FLOAT_HIGH_VALUE_POW_EXP = {
    DType.FP32: 10.0,
    DType.FP16: 10.0,
    DType.BF16: 10.0,
}
# POW highest base value (within a safe margin of error) that can be raised
# to +ve exponent that doesn't become Infinity.
# Computed as floor(HIGH_VALUE ** (1 / POW_EXP)), i.e. the POW_EXP-th root of
# the type's high value rounded down, so base ** POW_EXP <= HIGH_VALUE.
TVG_FLOAT_HIGH_VALUE_POW_BASE = {
    DType.FP32: math.floor(
        math.pow(
            TVG_FLOAT_HIGH_VALUE[DType.FP32],
            1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
        )
    ),
    DType.FP16: math.floor(
        math.pow(
            TVG_FLOAT_HIGH_VALUE[DType.FP16],
            1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
        )
    ),
    DType.BF16: math.floor(
        math.pow(
            TVG_FLOAT_HIGH_VALUE[DType.BF16],
            1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
        )
    ),
}
# POW lowest base value (within a safe margin of error) that can be raised
# to -ve exponent that doesn't become Infinity.
# Computed as (1 / HIGH_VALUE) ** (1 / POW_EXP) rounded UP to 3 decimal
# places (ceil(x * 1000) / 1000), so base ** -POW_EXP <= HIGH_VALUE.
TVG_FLOAT_LOW_VALUE_POW_BASE = {
    DType.FP32: math.ceil(
        math.pow(
            1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
            1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
        )
        * 1000
    )
    / 1000,
    DType.FP16: math.ceil(
        math.pow(
            1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
            1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
        )
        * 1000
    )
    / 1000,
    DType.BF16: math.ceil(
        math.pow(
            1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
            1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
        )
        * 1000
    )
    / 1000,
}
1570
@staticmethod
def tvgPow(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
    """Generate POW inputs for one of three test sets (``argsDict["s"]``).

    * set 0: positive base, fractional exponent
    * set 1: positive base, integer (rounded) exponent
    * set 2: negative base, integer (rounded) exponent

    Base and exponent ranges come from the POW lookup tables so results stay
    finite. ERROR_IF tests skip all of this and use the default generator.
    """
    tvg = TosaTensorValuesGen
    if error_name is not None:
        return tvg.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    dtype = dtypeList[0]
    test_set = argsDict["s"]

    def positive_base_range():
        # Bases between the POW low and high base limits
        return tvg._get_data_range(
            rng,
            dtype,
            tvg.TVG_FLOAT_HIGH_VALUE_POW_BASE,
            tvg.TVG_FLOAT_LOW_VALUE_POW_BASE,
        )

    def exponent_range():
        return tvg._get_data_range(
            rng, dtype, tvg.TVG_FLOAT_HIGH_VALUE_POW_EXP
        )

    # NOTE: the order of the _get_data_range calls (base/exponent) differs
    # between set 0 and the others; it is kept exactly as before since the
    # calls receive the shared rng.
    if test_set == 0:
        base_range = positive_base_range()
        exp_range = exponent_range()
        exp_round = False
    else:
        exp_range = exponent_range()
        exp_round = True
        if test_set == 1:
            base_range = positive_base_range()
        else:
            assert test_set == 2
            # Mirror the positive base limits into negative look up tables
            base_range = tvg._get_data_range(
                rng,
                dtype,
                {dtype: -tvg.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
                {dtype: -tvg.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
            )

    argsDict["data_range_list"] = (
        {"range": base_range},
        {"range": exp_range, "round": exp_round},
    )
    return tvg.tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
    )
1630
@staticmethod
def tvgLogRsqrt(
    testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
    """Generate LOG / RSQRT inputs.

    The data range spans the type's low value (lowest expressible positive
    number) up to its high value, to avoid NaN results.
    """
    tvg = TosaTensorValuesGen
    data_range = tvg._get_data_range(
        rng,
        dtypeList[0],
        tvg.TVG_FLOAT_HIGH_VALUE,
        tvg.TVG_FLOAT_LOW_VALUE,
    )
    if data_range:
        argsDict["data_range"] = data_range

    return tvg.tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
    )
1649
# EXP input data range, per float type: the natural log of the type's high
# value (upper bound) and of its low value (lower bound). exp() of any input
# in this range lies between LOW_VALUE and HIGH_VALUE, avoiding infinities
# and results that underflow to zero.
TVG_FLOAT_HIGH_VALUE_EXP = {
    DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
    DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
    DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
}
TVG_FLOAT_LOW_VALUE_EXP = {
    DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
    DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
    DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
}
1662
@staticmethod
def tvgExp(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
    """Generate EXP inputs limited to the EXP high/low log-value tables."""
    tvg = TosaTensorValuesGen
    data_range = tvg._get_data_range(
        rng,
        dtypeList[0],
        tvg.TVG_FLOAT_HIGH_VALUE_EXP,
        tvg.TVG_FLOAT_LOW_VALUE_EXP,
    )
    if data_range:
        argsDict["data_range"] = data_range

    return tvg.tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
    )
1677
@staticmethod
def tvgFullyConnected(
    testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
    """Generate FULLY_CONNECTED inputs using the lazy data generator.

    For non ERROR_IF BF16 tests that are not using the dot product data
    generator, the data range high value is the IC-th root of the type's
    high value (IC = shapeList[0][1]). Multiplying along the input channel
    axis then cannot hit infinity/NaN.
    """
    dtype = dtypeList[0]
    # "dg_type" holds a gtu.DataGenType value (see
    # TosaArgGen._add_data_generators), so compare against DataGenType
    # rather than the unrelated ComplianceMode enum.
    if (
        error_name is None
        and argsDict["dg_type"] != gtu.DataGenType.DOT_PRODUCT
        and dtype in (DType.BF16,)
    ):
        # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
        # Limit ranges for (non error & non compliance) FP tests by using
        # values that can be multiplied on any axis to not hit infinity/NaN
        IC = shapeList[0][1]
        highval_lookup = {
            dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
        }
        data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
        assert data_range is not None
        argsDict["data_range"] = data_range

    return TosaTensorValuesGen.tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
    )
1702
@staticmethod
def tvgCast(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
    """Generate CAST inputs limited to the output type's maximum value.

    Capping the input high value at the (inclusive) maximum of
    ``argsDict["out_type"]`` avoids FP infinities and integer saturation.
    """
    in_dtype = dtypeList[0]
    out_max = rng.dTypeRange(argsDict["out_type"], high_inclusive=True)[1]
    data_range = TosaTensorValuesGen._get_data_range(
        rng, in_dtype, {in_dtype: out_max}
    )
    assert data_range is not None
    argsDict["data_range"] = data_range

    return TosaTensorValuesGen.tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
    )
1723
@staticmethod
def tvgGather(
    testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
    """Generate GATHER inputs (values tensor plus INT32 indices).

    Indices must address dimension K = shapeList[0][1], so they are kept
    within 0..K-1. Types supported by the compliance data generator use
    the lazy generator; others are built here, with indices as a CONST.
    ``dtypeList[1]`` is forced to INT32 in place.
    """
    K = shapeList[0][1]

    # Fix the type of the indices tensor
    dtypeList[1] = DType.INT32

    if gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
        # Data generator supported: indices use inclusive range 0..K-1
        argsDict["data_range_list"] = (
            {"range": None},
            {"range": (0, K - 1)},
        )
        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    # Test unsupported by data generator
    p_count, c_count = testGen.TOSA_OP_LIST[opName]["operands"]
    assert (
        p_count == 2 and c_count == 0
    ), "Op.GATHER must have 2 placeholders, 0 consts"

    tens_ser_list = []
    for idx, shape in enumerate(shapeList):
        tens_dtype = dtypeList[idx]
        if idx == 1:
            # Indices limited to 0..K (K exclusive); created as CONST to
            # match the old functionality
            indices = rng.randTensor(shape, tens_dtype, (0, K))
            tens_ser_list.append(testGen.ser.addConst(shape, tens_dtype, indices))
        else:
            values = rng.randTensor(shape, tens_dtype)
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shape, tens_dtype, values)
            )

    return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1768
@staticmethod
def tvgScatter(
    testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
):
    """Generate SCATTER inputs (values_in, INT32 indices, third input).

    Indices stay within dimension K = shapeList[0][1]. When built here (for
    types the compliance data generator does not support), each batch row
    of indices is a distinct selection of W = shapeList[2][1] positions.
    ``dtypeList[1]`` is forced to INT32 in place.
    """
    K = shapeList[0][1]
    W = shapeList[2][1]

    # The spec says: "It is not permitted to repeat the same output index
    # within a single SCATTER operation and so each output index occurs at
    # most once." - which needs at least W distinct positions in K.
    assert K >= W, "Op.SCATTER W must be smaller or equal to K"

    # Fix the type of the indices tensor
    dtypeList[1] = DType.INT32

    if gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
        # Data generator supported: indices use inclusive range 0..K-1.
        # NOTE(review): index uniqueness is not requested here; presumably
        # the data generator enforces it for SCATTER - confirm.
        argsDict["data_range_list"] = (
            {"range": None},
            {"range": (0, K - 1)},
            {"range": None},
        )
        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    # Test unsupported by data generator
    p_count, c_count = testGen.TOSA_OP_LIST[opName]["operands"]
    assert (
        p_count == 3 and c_count == 0
    ), "Op.SCATTER must have 3 placeholders, 0 consts"

    tens_ser_list = []
    for idx, shape in enumerate(shapeList):
        tens_dtype = dtypeList[idx]
        if idx == 1:
            assert tens_dtype == DType.INT32, "Op.SCATTER unexpected indices type"
            # One row per batch: a shuffled 0..K-1 truncated to W entries,
            # so no output index repeats within the row
            indices_arr = np.array(
                [rng.permutation(K)[:W] for _ in range(shape[0])],
                dtype=np.int32,
            )  # (N, W)
            # To match old functionality - create indices as CONST
            tens_ser_list.append(
                testGen.ser.addConst(shape, tens_dtype, indices_arr)
            )
        else:
            values = rng.randTensor(shape, tens_dtype)
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shape, tens_dtype, values)
            )

    return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1830
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001831
1832class TosaArgGen:
1833 """Argument generators create exhaustive or random lists of attributes for
1834 operators that take attributes or other parameters.
1835
1836 The return value is a list of (descriptive_name, [arglist]) tuples where
1837 the descriptive_name is appended to the test name and the arglist is expanded
1838 as arguments to the operator build function.
1839 """
1840
def __init__(self):
    """No instance state to set up; the generators here are static methods."""
1843
@staticmethod
def _add_data_generators(testGen, opName, shapeList, dtype, arg_list, error_name):
    """Add extra tests for each type of data generator for this op.

    Every (arg_str, args_dict) entry in ``arg_list`` is repeated for each
    data generator type listed under "data_gen" for ``opName``/``dtype`` in
    ``testGen.TOSA_OP_LIST``. Only PSEUDO_RANDOM is used for ERROR_IF tests,
    ops without a "data_gen" entry and dtypes unsupported by compliance.

    Each copy of args_dict gets "dg_type" set. When several test sets are
    requested, one entry per set is made with "s" set to the set index and
    "_s<index>" appended to arg_str (or arg_str "s<index>" if it was empty).

    Args:
        testGen: test generator; provides TOSA_OP_LIST, shapeStr and the
            TOSA_MI_DOT_PRODUCT_MIN / TOSA_MI_DOT_PRODUCT_TEST_SETS limits.
        opName: operator name key into TOSA_OP_LIST.
        shapeList: input shapes; shapeList[0] sizes FULL_RANGE tests.
        dtype: data type used to select the generator types.
        arg_list: list of (arg_str, args_dict) tuples to expand.
        error_name: ERROR_IF name, or None for a normal test.

    Returns:
        New list of (arg_str, args_dict) tuples. Dot product tests with too
        few calculations and FULL_RANGE tests whose tensor is too small are
        logged and skipped.
    """
    if (
        error_name is None
        and "data_gen" in testGen.TOSA_OP_LIST[opName]
        and gtu.dtypeIsSupportedByCompliance(dtype)
    ):
        dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"].get(
            dtype, (gtu.DataGenType.PSEUDO_RANDOM,)
        )
    else:
        # Error test, unsupported dtype or no data generator types listed
        # - assume random
        dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)

    # Expand arg list with other data generator types
    new_arg_list = []
    for dg_type in dataGenTypesList:
        for arg_str, args_dict in arg_list:
            gen_args_dict = args_dict.copy()
            # Default is a single test. Reset for every entry so a generator
            # type without its own branch below gets one test instead of
            # reusing the count from a previous iteration (or raising
            # UnboundLocalError on the very first one).
            num_test_sets = 0

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                if error_name is None:
                    # Ops may ask for several pseudo random test sets
                    num_test_sets = args_dict.get("num_test_sets", 0)
                # ERROR_IF tests keep the single pseudo random test

            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                # Extra tests for each dot product test set
                dot_products = args_dict["dot_products"]
                if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
                    shape_info = (
                        " ({})".format(testGen.shapeStr(args_dict["shape"]))
                        if "shape" in args_dict
                        else ""
                    )
                    logger.info(
                        f"Skipping {opName}{shape_info} {gtu.DTYPE_ATTRIBUTES[dtype]['json']} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
                    )
                    continue
                # KS and acc_type is required by all dot product generators
                assert "ks" in args_dict
                assert "acc_type" in args_dict

                num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS

            elif dg_type == gtu.DataGenType.FULL_RANGE:
                tensor_size = gtu.product(shapeList[0])
                if tensor_size < gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]:
                    shape_info = " ({})".format(shapeList[0])
                    logger.info(
                        f"Skipping {opName}{shape_info} as tensor data size too small for full range of values {tensor_size} < {gtu.DTYPE_ATTRIBUTES[dtype]['fullset']}"
                    )
                    continue
                # Large enough tensor data size for full range: a single
                # test (num_test_sets stays 0)
                arg_str = f"{arg_str}_full" if arg_str else "full"
                gen_args_dict["tags"] = args_dict.get("tags", []) + [
                    "non_finite_fp_data"
                ]

            gen_args_dict["dg_type"] = dg_type
            if num_test_sets > 0:
                for s in range(num_test_sets):
                    set_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
                    set_args_dict = gen_args_dict.copy()
                    set_args_dict["s"] = s
                    new_arg_list.append((set_arg_str, set_args_dict))
            else:
                # Default is a single test
                new_arg_list.append((arg_str, gen_args_dict))

    return new_arg_list
1922
@staticmethod
def agNone(testGen, rng, opName, shapeList, dtype, error_name=None):
    """Argument generator for operators with no non-tensor arguments.

    Returns a list of (arg_str, args_dict) tuples: a single empty argument
    entry, expanded once per applicable data generator type.
    """
    no_args = [("", {})]
    return TosaArgGen._add_data_generators(
        testGen, opName, shapeList, dtype, no_args, error_name
    )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001937
@staticmethod
def agPow(testGen, rng, opName, shapeList, dtype, error_name=None):
    """Argument generator for POW.

    Requests three test sets (selected later through the "s" argument) so
    random data can cover different base/exponent ranges without creating
    NaNs or Infs. Returns a list of (arg_str, args_dict) tuples.
    """
    three_sets = [("", {"num_test_sets": 3})]
    return TosaArgGen._add_data_generators(
        testGen, opName, shapeList, dtype, three_sets, error_name
    )
1952
@staticmethod
def agAxis(testGen, rng, opName, shapeList, dtype, error_name=None):
    """Build the axis argument for operators that take a single axis.

    Normal tests get one entry per dimension of shapeList[0]. The
    AxisSmallerZero / AxisLargerRank ERROR_IF tests get one random axis
    outside the valid range. REDUCE_SUM entries also carry the dot product
    metadata ("dot_products", "shape", "ks", "acc_type") used by the data
    generators. Returns a list of (arg_str, args_dict) tuples.
    """
    shape = shapeList[0]
    rank = len(shape)

    if error_name == ErrorIf.AxisSmallerZero:
        # Set too small axis
        axes = [rng.integers(-5, 0)]
    elif error_name == ErrorIf.AxisLargerRank:
        # Set too large axis
        axes = [rng.integers(rank + 1, rank + 10)]
    else:
        # Create tests for each dimension
        axes = range(rank)

    is_reduce_sum = testGen.TOSA_OP_LIST[opName]["op"] == Op.REDUCE_SUM

    arg_list = []
    for axis in axes:
        args_dict = {"axis": int(axis)}
        if is_reduce_sum:
            reduced_shape = shape.copy()
            if error_name is None:
                # Only non ERROR_IF tests are ever run, so only they need
                # the correct dot product count
                reduced_shape[axis] = 1
            args_dict["dot_products"] = gtu.product(reduced_shape)
            args_dict["shape"] = shape
            args_dict["ks"] = int(shape[axis]) if 0 <= axis < rank else 1
            args_dict["acc_type"] = DType.FP32 if dtype == DType.BF16 else dtype

        arg_list.append((f"axis{axis}", args_dict))

    return TosaArgGen._add_data_generators(
        testGen, opName, shapeList, dtype, arg_list, error_name
    )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001996
@staticmethod
def _calculate_sparsity(num_tests, sparsity_factor):
    """Return the sparsity step for ``num_tests`` tests.

    The step is ``num_tests // sparsity_factor + 1``. Below 13 every test is
    selected (step 1). Otherwise the step is bumped until it is not a
    multiple of 2, 3 or 5, to get a variety of parameter combinations.
    """
    step = num_tests // sparsity_factor + 1
    if step < 13:
        # Only a small number of tests - just select them all
        return 1
    while any(step % prime == 0 for prime in (2, 3, 5)):
        step += 1
    return step
2008
# Maximum number of ERROR_IF test variants to produce
# (class-level limit shared by the argument generators)
MAX_TESTS_ERROR_IFS = 3
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002011
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002012 @staticmethod
def agConv(testGen, rng, opName, shapeList, dtypes, error_name=None):
    """Build the argument list for CONV2D, CONV3D and DEPTHWISE_CONV2D.

    Enumerates combinations of accumulator dtype, stride, padding and
    dilation (sampled sparsely for large positive-test sets). It keeps those
    that give a valid convolution, or, for ERROR_IF tests, those that
    exhibit ``error_name``.

    Args:
        testGen: test generator; supplies ``TOSA_OP_LIST``, ``args``, the
            8k level limits and ``typeStr``.
        rng: random generator used to pick invalid values for some ERROR_IF
            tests.
        opName: operator name, a key of ``testGen.TOSA_OP_LIST``.
        shapeList: ``[ifm_shape, filter_shape, ...]``.
        dtypes: tensor dtypes; ``dtypes[0]`` is the input dtype.
        error_name: ERROR_IF condition to generate, or ``None``.

    Returns:
        List of ``(arg_str, args_dict)`` tuples, after being passed through
        ``TosaArgGen._add_data_generators``.
    """
    arg_list = []

    # Shape: Batches, (Depth), Height, Width, Channels
    ifm_shape = shapeList[0]
    # Shape: (OFM channels), (KD), KH, KW, IFM channels
    filter_shape = shapeList[1]

    accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)

    if error_name == ErrorIf.WrongAccumulatorType:
        accum_dtypes = (
            [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
        )

    # For op type checks
    op = testGen.TOSA_OP_LIST[opName]

    # Check the rank
    rank = 5 if op["op"] == Op.CONV3D else 4
    if error_name != ErrorIf.WrongRank:
        assert len(ifm_shape) == rank
        assert len(filter_shape) == rank

    # kernel rank omits channels
    k_rank = rank - 2
    k_pos = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
    k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
    # compliance size - KS
    k_size = gtu.product(k_shape)
    if not op["op"] == Op.DEPTHWISE_CONV2D:
        k_size *= ifm_shape[-1]

    def get_conv_output_info(p, s, d, fix_up_padding=False):
        """Return (paddings, remainders, outputs, outputs_no_stride).

        One entry per spatial dimension. When ``fix_up_padding`` is set, the
        trailing padding of each dimension is increased until the convolution
        is valid: a positive output size, and a zero remainder (or a non-zero
        one for ConvOutputShapeNonInteger). The adjusted paddings are
        returned as a new tuple; the caller's ``p`` is never mutated.
        """
        if fix_up_padding:
            p = list(p)  # Make paddings editable
        outputs_no_stride = []
        remainders = []
        outputs = []
        for index in range(k_rank):
            pad_offset = index * 2
            fixed = False
            # Fix up pad values to produce valid conv2d
            while not fixed:
                # Output dimension without being adjusted for stride
                output_no_stride = (
                    ifm_shape[index + 1]
                    - 1
                    + p[pad_offset]
                    + p[pad_offset + 1]
                    - (k_shape[index] - 1) * d[index]
                )
                # Tensor left over after applying striding
                remainder = output_no_stride % s[index]
                if not fix_up_padding:
                    # Just want remainders and outputs
                    break
                if output_no_stride <= 0:
                    p[pad_offset + 1] += abs(output_no_stride) + 1
                    continue
                if error_name == ErrorIf.ConvOutputShapeNonInteger:
                    if remainder:
                        # Conditions to trigger the test
                        fixed = True
                    else:
                        p[pad_offset + 1] += 1
                else:
                    if remainder:
                        # Stride will be negative for StrideSmallerOne
                        assert remainder > 0 or (
                            error_name == ErrorIf.StrideSmallerOne and remainder < 0
                        )
                        p[pad_offset + 1] += abs(remainder)
                    else:
                        fixed = True
            outputs_no_stride.append(output_no_stride)
            remainders.append(remainder)
            # Output dimension taking in to account stride
            outputs.append((output_no_stride // s[index]) + 1)

        if fix_up_padding:
            p = tuple(p)  # Make the paddings read-only
            assert min(outputs_no_stride) > 0, "Fix up did not work!"
        return p, remainders, outputs, outputs_no_stride

    # Only fix up padding for conv2d and float types currently
    fix_up_padding = gtu.dtypeIsFloat(dtypes[0]) and op["op"] == Op.CONV2D
    # Allow any size of output dimension
    max_dim_size = None
    # Include all tests by default
    sparsity = 1

    # Work out padding, strides and dilation ranges depending on
    # error and arguments
    if error_name in (
        ErrorIf.PadSmallerZero,
        ErrorIf.StrideSmallerOne,
        ErrorIf.DilationSmallerOne,
    ):
        # Use specific invalid value(s)
        if error_name == ErrorIf.PadSmallerZero:
            # Create negative paddings but with positive opposite paddings
            neg_pad = rng.choice(range(-5, 0))
            p_vals = [neg_pad, abs(neg_pad)]
        else:
            p_vals = [0, 0]
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [rng.choice(range(-5, 0))]
        else:
            s_vals = [1]
        if error_name == ErrorIf.DilationSmallerOne:
            d_vals = [rng.choice(range(-5, 1))]
        else:
            d_vals = [1]
        paddings = {tuple(p_vals) * k_rank}
        strides = {tuple(s_vals) * k_rank}
        dilations = {tuple(d_vals) * k_rank}

        fix_up_padding = True  # Need to fix up paddings to be valid

    elif testGen.args.level8k and error_name is None:
        # Only test 8k levels boundaries
        bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
        bigPadding = bigKernel

        dilation_shape = [1] * k_rank
        pad_shape = [0] * k_rank * 2
        if op["op"] == Op.CONV3D:
            # Small stride apart from for big kernel (see below) to keep
            # tensor size/calculation small
            stride_shape = [1] * k_rank
            for idx in range(k_rank):
                pad_offset = idx * 2
                if k_shape[idx] == bigKernel:
                    # Padding shape needs to account for tensor shape
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                    # Big stride to reduce output size
                    stride_shape[idx] = bigKernel
                else:
                    # Account for kernel size
                    pad_shape[pad_offset] = k_shape[idx] - 1
        else:
            # Always have a large stride with extra padding and dilation to keep
            # tensor calculation reasonable
            stride_shape = [bigKernel] * k_rank
            for idx in range(k_rank):
                # Dilation shape must account for kernel size
                dilation_shape[idx] = bigKernel // k_shape[idx]
                # Padding shape needs to accommodate tensor/kernel & dilation
                pad_offset = idx * 2
                pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

        strides = {tuple(stride_shape)}
        dilations = {tuple(dilation_shape)}
        paddings = {tuple(pad_shape)}
        # Create a limit for the output dimensions size
        max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

        # Currently allow all combinations that are reasonable size
        sparsity = 1
    else:
        # Generate comprehensive argument lists
        p_vals = list(range(0, testGen.args.max_conv_padding + 1))
        paddings = set(itertools.product(*([p_vals] * k_rank * 2)))
        # Stride must be greater than 1 to force non-integer error
        startStride = 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
        s_vals = list(range(startStride, testGen.args.max_conv_stride + 1))
        d_vals = list(range(1, testGen.args.max_conv_dilation + 1))

        strides = set(itertools.product(*([s_vals] * k_rank)))
        dilations = set(itertools.product(*([d_vals] * k_rank)))

        if error_name is None and testGen.args.oversize:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    itertools.product(*([[0, bigPadding]] * (k_rank * 2)))
                )
            bigStride = 8
            strides.update(itertools.product(*([[1, bigStride]] * k_rank)))
            bigDilation = 7
            dilations.update(itertools.product(*([[1, bigDilation]] * k_rank)))

        if error_name is None:
            # There are too many parameter combinations, so generate them sparsely,
            sparsity_factor = 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(paddings) * len(strides) * len(dilations), sparsity_factor
            )

    # Run through all the argument options creating valid test cases
    more_tests = True
    n = 0
    for a in accum_dtypes:
        for s in sorted(strides):
            for p in sorted(paddings):
                for d in sorted(dilations):
                    if more_tests and (n % sparsity == 0):
                        # Bind the (possibly fixed-up) padding to its own name;
                        # rebinding the loop variable `p` here would make later
                        # dilations start from this dilation's adjusted padding
                        (
                            pad,
                            remainders,
                            outputs,
                            outputs_no_stride,
                        ) = get_conv_output_info(p, s, d, fix_up_padding)
                        # Following is like checking each dimension N:
                        # (ifm_shape[N+1] - 1 + pad[N*2] + pad[N*2+1]) > d[N] * (k_shape[N] - 1)
                        if min(outputs_no_stride) <= 0:
                            # Not a valid operation
                            n += 1  # Increment count of tests
                            continue

                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            if (
                                max_dim_size is not None
                                and max(outputs) >= max_dim_size
                            ):
                                # Test will consume too much memory - skip it
                                logger.debug(
                                    "agConv: Convolution output too big - skipped"
                                )
                                continue

                            # Compliance - number of dot product calculations
                            if op["op"] == Op.DEPTHWISE_CONV2D:
                                # N*OH*OW*C*M
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, *filter_shape[2:])
                                )
                            else:
                                # N*OH*OW*OC or N*OD*OH*OW*OC
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, filter_shape[0])
                                )
                            args_dict = {
                                "acc_type": a,
                                "stride": s,
                                "pad": pad,
                                "dilation": d,
                                "kernel": k_shape,
                                "ks": k_size,
                                "dot_products": dots,
                                "shape": ifm_shape,
                            }

                            # Support for larger values than 9 needs different delimiter
                            delim = "" if max(s + pad + d) <= 9 else "x"
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(a),
                                        delim.join([str(x) for x in s]),
                                        delim.join([str(x) for x in pad]),
                                        delim.join([str(x) for x in d]),
                                    ),
                                    args_dict,
                                )
                            )
                            if (
                                error_name
                                and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS
                            ):
                                # Found enough errors
                                logger.debug(
                                    f"Skipping creating more conv error tests for {error_name}"
                                )
                                more_tests = False
                    n += 1

    arg_list = TosaArgGen._add_data_generators(
        testGen,
        opName,
        shapeList,
        dtypes[0],
        arg_list,
        error_name,
    )
    # Return list of tuples: (arg_str, args_dict)
    return arg_list
2316
2317 @staticmethod
def agFullyConnected(testGen, rng, opName, shapeList, dtypes, error_name=None):
    """Build the single argument entry for FULLY_CONNECTED.

    Args:
        testGen: test generator (provides ``typeStr``).
        rng: random generator, only used to pick a wrong output type.
        opName: operator name.
        shapeList: ``[input_shape, weight_shape, ...]``; input is (N, IC) and
            ``weight_shape[0]`` is taken as the output-channel count.
        dtypes: list/tuple of dtypes; first is the input, last the output.
        error_name: ERROR_IF condition to generate, or ``None``.

    Returns:
        List of ``(arg_str, args_dict)`` tuples, after being passed through
        ``TosaArgGen._add_data_generators``.
    """
    assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
    input_dtype = dtypes[0]

    if error_name == ErrorIf.WrongOutputType:
        acc_type = gtu.get_wrong_output_type(opName, rng, input_dtype)
    elif error_name == ErrorIf.WrongInputType:
        # Input dtype is deliberately invalid: fall back to a plausible accumulator
        acc_type = DType.INT32
    else:
        # The output dtype doubles as the accumulator dtype
        acc_type = dtypes[-1]

    input_shape = shapeList[0]
    weight_shape = shapeList[1]

    # Compliance information: KS = IC, dot products = N * OC
    compliance = {
        "acc_type": acc_type,
        "ks": int(input_shape[1]),
        "dot_products": gtu.product((input_shape[0], weight_shape[0])),
        "shape": input_shape,
    }
    entries = [(f"acc{testGen.typeStr(acc_type)}", compliance)]

    return TosaArgGen._add_data_generators(
        testGen,
        opName,
        shapeList,
        input_dtype,
        entries,
        error_name,
    )
James Ward8b390432022-08-12 20:48:56 +01002351
2352 @staticmethod
def agMatMul(testGen, rng, opName, shapeList, dtype, error_name=None):
    """Build the argument list for MATMUL, one entry per accumulator dtype.

    Args:
        testGen: test generator (provides ``typeStr``).
        rng: random generator, only used to pick a wrong output type.
        opName: operator name.
        shapeList: ``[a_shape, b_shape, ...]`` with A as (N, H, C) and
            ``b_shape[2]`` used as W.
        dtype: input dtype.
        error_name: ERROR_IF condition to generate, or ``None``.

    Returns:
        List of ``(arg_str, args_dict)`` tuples, after being passed through
        ``TosaArgGen._add_data_generators``.

    Raises:
        AssertionError: for an unsupported ``dtype`` on a non-error test.
    """
    # Get valid accumulate type(s)
    if dtype == DType.INT8:
        accum_dtypes = [DType.INT32]
    elif dtype == DType.INT16:
        accum_dtypes = [DType.INT48]
    elif dtype == DType.FP16:
        accum_dtypes = [DType.FP16, DType.FP32]
    elif dtype in (DType.BF16, DType.FP32):
        accum_dtypes = [DType.FP32]
    elif dtype in (DType.FP8E4M3, DType.FP8E5M2):
        accum_dtypes = [DType.FP16]
    elif error_name is None:
        # Explicit raise so the check is not stripped under `python -O`
        raise AssertionError(f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}")
    else:
        # Set to something for the ErrorIf case which has incorrect input
        # data-type; previously accum_dtypes was left unbound here
        accum_dtypes = [DType.INT32]

    if error_name == ErrorIf.WrongOutputType:
        # Get incorrect output dtype for ErrorIf case
        accum_dtypes = [gtu.get_wrong_output_type(opName, rng, dtype)]
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        accum_dtypes = [DType.INT32]

    # Set up compliance info
    args_dict = {
        "ks": int(shapeList[0][2]),  # Set KS = C, from input A (N,H,C)
        # Set dot_products = N*H*W
        "dot_products": gtu.product(
            (shapeList[0][0], shapeList[0][1], shapeList[1][2])
        ),
        "shape": shapeList[0],
    }

    # Create arg tuple of string and dict (a fresh dict per accumulator type)
    arg_list = [
        (f"acc{testGen.typeStr(a)}", {**args_dict, "acc_type": a})
        for a in accum_dtypes
    ]

    arg_list = TosaArgGen._add_data_generators(
        testGen,
        opName,
        shapeList,
        dtype,
        arg_list,
        error_name,
    )
    # Return list of tuples: (arg_str, args_dict)
    return arg_list
James Ward8b390432022-08-12 20:48:56 +01002404
2405 @staticmethod
def agTransposeConv2D(testGen, rng, opName, shapeList, dtypes, error_name=None):
    """Build the argument list for TRANSPOSE_CONV2D.

    Enumerates accumulator dtype, stride and output-padding combinations
    (sampled sparsely) and records the resulting output shape for each.

    Args:
        testGen: test generator; supplies ``args``, the 8k level limits and
            ``typeStr``.
        rng: random generator used to pick invalid values for some ERROR_IF
            tests.
        opName: operator name.
        shapeList: ``[ifm_shape, filter_shape, ...]``, both rank 4.
        dtypes: tensor dtypes; ``dtypes[0]`` is the input dtype.
        error_name: ERROR_IF condition to generate, or ``None``.

    Returns:
        List of ``(arg_str, args_dict)`` tuples, after being passed through
        ``TosaArgGen._add_data_generators``. Empty for ERROR_IF tests at
        level 8k.
    """
    arg_list = []

    if testGen.args.level8k and error_name is not None:
        # Don't produce negative large tests
        return arg_list

    ifm_shape = shapeList[0]
    filter_shape = shapeList[1]

    accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)

    if error_name == ErrorIf.WrongAccumulatorType:
        accum_dtypes = (
            [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
        )

    # Must be rank 4
    if error_name != ErrorIf.WrongRank:
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

    k_shape = tuple(filter_shape[1:3])
    # compliance size - KS
    k_size = gtu.product((*k_shape, ifm_shape[3]))

    if not testGen.args.level8k:
        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
        if error_name == ErrorIf.PadLargerEqualKernel:
            max_filter_size = -max(k_shape[0], k_shape[1])
            p_vals = [rng.choice(range(max_filter_size - 10, max_filter_size))]
        else:
            p_vals = list(
                range(smallest_padding_size, testGen.args.max_conv_padding + 1)
            )
        paddings = set(itertools.product(*([p_vals] * 4)))
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [rng.choice(range(-5, 0))]
        else:
            s_vals = list(range(1, testGen.args.max_conv_stride + 1))
        strides = set(itertools.product(*([s_vals] * 2)))

        if not error_name and testGen.args.oversize:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    itertools.product(*([[smallest_padding_size, bigPadding]] * 4))
                )
            bigStride = 8
            strides.update(itertools.product(*([[1, bigStride]] * 2)))

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 10
        sparsity = TosaArgGen._calculate_sparsity(
            len(paddings) * len(strides), sparsity_factor
        )
    else:
        # Only test 8k levels boundaries
        bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
        bigPadding = bigKernel

        pad_shape = [0] * (len(k_shape) * 2)
        stride_shape = [1] * len(k_shape)
        # The point at which input dimension combined with the stride will
        # create large output sizes!
        LARGE_SIZE = 2
        for idx in range(len(k_shape)):
            pad_offset = idx * 2
            if k_shape[idx] == bigKernel:
                # Set large stride
                stride_shape[idx] = bigKernel
                # Use negative output padding to reduce shape size
                pad_shape[pad_offset] = -(bigPadding - 1)
                if ifm_shape[idx + 1] > LARGE_SIZE:
                    pad_shape[pad_offset + 1] = -(bigPadding - 1)
            else:
                # The other dimension should be the bigKernel
                alt_idx = 1 - idx
                if (
                    k_shape[alt_idx] == bigKernel
                    and ifm_shape[alt_idx + 1] < LARGE_SIZE
                ):
                    # As the input is small, the large stride won't
                    # affect the output so we can add some padding
                    pad_shape[pad_offset + 1] = bigPadding

        strides = {tuple(stride_shape)}
        paddings = {tuple(pad_shape)}

        # Currently allow all combinations that are reasonable size
        sparsity = 1

    n = 0
    for a in accum_dtypes:
        for s in sorted(strides):
            for p in sorted(paddings):
                if n % sparsity == 0:
                    # Determine the output shape
                    oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
                    ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
                    out_shape = [ifm_shape[0], oh, ow, filter_shape[0]]

                    # N*OH*OW*OC
                    dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
                    args_dict = {
                        "acc_type": a,
                        "stride": s,
                        "pad": p,
                        "kernel": k_shape,
                        "ks": k_size,
                        "dot_products": dots,
                        "shape": ifm_shape,
                        "out_shape": out_shape,
                    }

                    # Support for larger values than 9 needs different delimiter
                    delim = "" if max(s + p) <= 9 else "x"
                    arg_list.append(
                        (
                            "acc{}_st{}_pad{}_os{}".format(
                                testGen.typeStr(a),
                                delim.join([str(x) for x in s]),
                                delim.join([str(x) for x in p]),
                                "x".join([str(x) for x in out_shape]),
                            ),
                            args_dict,
                        )
                    )
                n += 1

    arg_list = TosaArgGen._add_data_generators(
        testGen,
        opName,
        shapeList,
        dtypes[0],
        arg_list,
        error_name,
    )
    # Return list of tuples: (arg_str, args_dict)
    return arg_list
2565
2566 @staticmethod
def agPad(testGen, rng, opName, shapeList, dtype, error_name=None):
    """Build the argument list for PAD.

    Enumerates (before, after) padding pairs for every dimension of the
    input, sampling sparsely when there are more than 1024 combinations.

    Args:
        testGen: test generator (provides ``args``).
        rng: random generator used for the pad constant.
        opName: operator name.
        shapeList: ``[input_shape, ...]``.
        dtype: input dtype; selects an integer or floating-point pad constant.
        error_name: ERROR_IF condition to generate, or ``None``.

    Returns:
        List of ``(arg_str, args_dict)`` tuples, after being passed through
        ``TosaArgGen._add_data_generators``.
    """
    rank = len(shapeList[0])

    if error_name is None and testGen.args.oversize:
        pad_values = [6, 7, 10, 13]
    elif error_name == ErrorIf.PadSmallerZero:
        pad_values = list(range(-2, 0))
    else:
        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        pad_min, pad_max = 0, 1
        pad_values = list(range(pad_min, pad_max + 1))

    # Calculate pad combinations
    axis_pad_values = list(itertools.product(pad_values, pad_values))
    shape_pad_values = itertools.product(*([axis_pad_values] * rank))

    if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
        pad_const_int = rng.randNumberDType(dtype)
        pad_const_fp = 0
    elif gtu.dtypeIsFloat(dtype):
        pad_const_int = 0
        pad_const_fp = rng.randNumberDType(dtype)
    else:
        assert error_name == ErrorIf.WrongInputType
        pad_const_int = 0
        pad_const_fp = 0

    # Count the combinations arithmetically instead of materialising them all:
    # with oversize values at rank 6 there are 16**6 (~16.7M) of them
    num_combinations = len(axis_pad_values) ** rank
    # If we are producing tests for rank 6 or greater use sparsity
    if num_combinations > 1024:
        sparsity_factor = 2 if error_name else 120
        sparsity = TosaArgGen._calculate_sparsity(num_combinations, sparsity_factor)
    else:
        sparsity = 1

    # Build arg list; islice yields exactly the combinations at indices that
    # are multiples of `sparsity`, in their original order
    arg_list = []
    for paddings in itertools.islice(shape_pad_values, 0, None, sparsity):
        paddings = list(paddings)
        args_valid = True

        if error_name == ErrorIf.PadSmallerZero:
            # Prevent negative output shapes while ensuring still testing for negative padding
            for i in range(rank):
                dim_after_padding = paddings[i][0] + paddings[i][1] + shapeList[0][i]
                if dim_after_padding < 1:
                    paddings[i] = (0, 0)
                if all(v > -1 for v in paddings[i]):
                    args_valid = False

        if args_valid:
            # Work out name
            pad_list = [v for axis_pads in paddings for v in axis_pads]

            delim = "" if max(pad_list) <= 9 else "x"
            name = "pad{}".format(delim.join([str(x) for x in pad_list]))

            args_dict = {
                "pad": np.array(paddings),
                "pad_const_int": pad_const_int,
                "pad_const_fp": pad_const_fp,
            }
            arg_list.append((name, args_dict))

    if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
        logger.debug(
            f"agPad: No PadSmallerZero ErrorIf test created for input shape: {shapeList[0]}"
        )

    arg_list = TosaArgGen._add_data_generators(
        testGen,
        opName,
        shapeList,
        dtype,
        arg_list,
        error_name,
    )

    # Return list of tuples: (arg_str, args_dict)
    return arg_list
2653
2654 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002655 def agPooling(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002656 arg_list = []
2657
2658 shape = shapeList[0]
2659 if error_name != ErrorIf.WrongRank:
2660 assert len(shape) == 4
2661
Jeremy Johnson0c716862023-04-13 17:18:19 +01002662 test_level8k = testGen.args.level8k and error_name is None
2663
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002664 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002665 startKernel = 2
2666 startPad = 0
2667 if not test_level8k:
2668 # Generate comprehensive argument lists
2669 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2670 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2671 # Stride must be greater than 1 to force non-integer error
2672 s_vals = [
2673 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2674 ]
2675 strides = {x for x in itertools.product(*([s_vals] * 2))}
2676 k_vals = [
2677 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2678 ]
2679 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2680 max_dim_size = None
2681 else:
2682 # Only test 8k levels
2683 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2684 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2685 strides = {(1, bigStride), (bigStride, 4)}
2686 kernels = {(1, bigKernel), (bigKernel, 3)}
2687 paddings = set()
2688 for s in sorted(list(strides)):
2689 for k in sorted(list(kernels)):
2690 padding = []
2691 for idx in range(len(k)):
2692 total_padding = s[idx] - shape[idx + 1] + k[idx]
2693 while total_padding < 0:
2694 # Must meet: shape + padding > kernel
2695 total_padding += s[idx]
2696 if total_padding < k[idx]:
2697 padding.extend([0, total_padding])
2698 else:
2699 # Note this may produce padding >= k[idx] which is not
2700 # allowed - but will be ignored in the creation loop below
2701 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2702 paddings.add(tuple(padding))
2703 # Create a limit for the output dimensions size
2704 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002705
James Ward8b390432022-08-12 20:48:56 +01002706 if opName == "max_pool2d":
2707 accum_dtypes = [None] # max_pool has no accumulate dtype
2708 elif dtype == DType.INT8 or dtype == DType.INT16:
2709 accum_dtypes = [DType.INT32]
2710 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002711 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002712 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002713 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002714 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2715 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002716 elif error_name is None:
2717 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2718 else:
2719 # Set to something for the ErrorIf case which has
2720 # incorrect input data-type
2721 accum_dtypes = [DType.INT32]
2722
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00002723 if error_name == ErrorIf.WrongAccumulatorType:
2724 accum_dtypes = list(gtu.usableDTypes(excludes=accum_dtypes))
2725
Jeremy Johnson0c716862023-04-13 17:18:19 +01002726 if not test_level8k:
2727 if testGen.args.oversize:
2728 # add some oversize argument values
2729 bigStride = 7
2730 bigKernel = 9
2731 strides.update(
2732 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002733 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002734 kernels.update(
2735 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2736 )
2737 if max(shape) < 64:
2738 # padding must be less than the kernel size
2739 bigPadding = bigKernel - 1
2740 paddings.update(
2741 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2742 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002743
Jeremy Johnson87460262024-03-25 09:46:02 +00002744 if error_name:
2745 # Cycle through all error_if tests but we only keep the first few
2746 sparsity = 1
2747 else:
2748 # There are too many parameter combinations, so generate them sparsely
2749 sparsity_factor = 500
2750 sparsity = (
2751 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2752 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002753 else:
2754 # We have already limited test output combinations for 8k tests
2755 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002756
James Ward8b390432022-08-12 20:48:56 +01002757 arg_str = (
2758 "acc{}_st{}_kern{}_pad{}"
2759 if accum_dtypes[0] is not None
2760 else "st{}_kern{}_pad{}"
2761 )
2762
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002763 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01002764 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002765 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01002766
2767 # Support for larger values than 9 needs different delimiter
2768 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01002769 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01002770 delim.join([str(x) for x in stride]),
2771 delim.join([str(x) for x in kern]),
2772 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01002773 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002774 args_dict = {
2775 "stride": stride,
2776 "pad": pad,
2777 "kernel": kern,
2778 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002779 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002780 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
2781 }
James Ward8b390432022-08-12 20:48:56 +01002782
2783 if accum is not None:
2784 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002785 args_dict["acc_type"] = accum
2786 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002787
Jeremy Johnson87460262024-03-25 09:46:02 +00002788 more_tests = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002789 n = 0
James Ward8b390432022-08-12 20:48:56 +01002790 for a in accum_dtypes:
2791 for s in sorted(list(strides)):
2792 for p in sorted(list(paddings)):
2793 for k in sorted(list(kernels)):
2794 if error_name in [
2795 ErrorIf.StrideSmallerOne,
2796 ErrorIf.KernelSmallerOne,
2797 ErrorIf.PadSmallerZero,
2798 ErrorIf.PadLargerEqualKernel,
2799 ]:
2800 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002801 rng, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002802 )
James Ward8b390432022-08-12 20:48:56 +01002803 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002804 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002805 get_arg_list_element(a, sNew, pNew, kNew, shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002806 )
James Ward8b390432022-08-12 20:48:56 +01002807 elif (
Jeremy Johnson87460262024-03-25 09:46:02 +00002808 more_tests
2809 and n % sparsity == 0
James Ward8b390432022-08-12 20:48:56 +01002810 # padding must not exceed the kernel size
2811 and p[0] < k[0]
2812 and p[1] < k[0]
2813 and p[2] < k[1]
2814 and p[3] < k[1]
2815 # the padded shape must exceed the kernel size
2816 and (shape[1] + p[0] + p[1]) > k[0]
2817 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002818 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002819 partial_h = shape[1] + p[0] + p[1] - k[0]
2820 partial_w = shape[2] + p[2] + p[3] - k[1]
2821 remainder_h = partial_h % s[0]
2822 remainder_w = partial_w % s[1]
2823 output_h = partial_h // s[0] + 1
2824 output_w = partial_w // s[1] + 1
Jeremy Johnsonaf090182024-02-13 18:25:39 +00002825 logger.debug(
2826 f"agPooling: {shape} remainder=({remainder_h}, {remainder_w}) output=({output_h}, {output_w})"
2827 )
James Ward8b390432022-08-12 20:48:56 +01002828 if (
2829 # the parameters must produce integer exact output
2830 error_name != ErrorIf.PoolingOutputShapeNonInteger
2831 and remainder_h == 0
2832 and remainder_w == 0
2833 ) or (
2834 error_name == ErrorIf.PoolingOutputShapeNonInteger
2835 and (remainder_h != 0 or remainder_w != 0)
2836 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002837 if (
2838 max_dim_size is not None
2839 and max(output_h, output_w) > max_dim_size
2840 ):
2841 # Test will consume too much memory - skip it
2842 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002843 # Dot products = N*OH*OW*C
2844 dp = gtu.product(
2845 (shape[0], output_h, output_w, shape[3])
2846 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002847 arg_list.append(
2848 get_arg_list_element(a, s, p, k, dp, shape)
2849 )
Jeremy Johnson87460262024-03-25 09:46:02 +00002850 if (
2851 error_name
2852 and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS
2853 ):
2854 # Found enough errors
2855 logger.debug(
2856 f"Skipping creating more pooling error tests for {error_name}"
2857 )
2858 more_tests = False
2859
James Ward8b390432022-08-12 20:48:56 +01002860 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002861
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002862 # Now add data generator types
2863 arg_list = TosaArgGen._add_data_generators(
2864 testGen,
2865 opName,
evacha019c96eef2024-02-07 11:21:55 +00002866 shapeList,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002867 dtype,
2868 arg_list,
2869 error_name,
2870 )
2871
2872 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002873 return arg_list
2874
@staticmethod
def agCast(testGen, rng, opName, shapeList, inDtype, error_name=None):
    """Generate CAST arguments: one test per legal output dtype for inDtype.

    Args:
        testGen: test generator; provides typeStr() for naming the tests.
        rng: random generator (unused here, kept for the common ag* signature).
        opName: operator name, forwarded to the data generator helper.
        shapeList: list of input shapes, forwarded to the data generator helper.
        inDtype: input tensor dtype.
        error_name: ErrorIf under test, or None for a valid test.

    Returns:
        List of (arg_str, args_dict) tuples where args_dict["out_type"] is the
        output dtype, expanded by TosaArgGen._add_data_generators.

    Raises:
        Exception: for an unsupported inDtype, unless error_name is
            ErrorIf.WrongOutputType or ErrorIf.WrongInputType.
    """
    arg_list = []

    # Enumerate the output types here
    if error_name == ErrorIf.WrongOutputType:
        dtypeList = TosaErrorIfArgGen.eiCastErrorIf(inDtype)
    elif inDtype == DType.INT8:
        dtypeList = [
            DType.BOOL,
            DType.INT16,
            DType.INT32,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
    elif inDtype == DType.INT16:
        dtypeList = [
            DType.BOOL,
            DType.INT8,
            DType.INT32,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
    elif inDtype == DType.INT32:
        dtypeList = [
            DType.BOOL,
            DType.INT8,
            DType.INT16,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
    elif inDtype == DType.BOOL:
        dtypeList = [DType.INT8, DType.INT16, DType.INT32]
    elif inDtype == DType.FP16:
        dtypeList = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.FP32,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ]
    elif inDtype == DType.BF16:
        dtypeList = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.FP32,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ]
    elif inDtype == DType.FP32:
        dtypeList = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ]
    elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]:
        dtypeList = [DType.FP16, DType.BF16, DType.FP32]
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output type for incorrect input type
        dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
    else:
        raise Exception("Unexpected input dtype: {}".format(inDtype))

    for outDtype in dtypeList:
        arg_list.append(
            ("out{}".format(testGen.typeStr(outDtype)), {"out_type": outDtype})
        )

    # Now add data generator types.
    # Pass the input dtype, as agRescale does. Previously the leaked loop
    # variable was passed, i.e. whichever output dtype happened to be last in
    # dtypeList; that name is also unbound when dtypeList is empty.
    arg_list = TosaArgGen._add_data_generators(
        testGen,
        opName,
        shapeList,
        inDtype,
        arg_list,
        error_name,
    )

    # Return list of tuples: (arg_str, args_dict)
    return arg_list
2963
@staticmethod
def agRescale(testGen, rng, opName, shapeList, inDtype, error_name=None):
    """Generate RESCALE arguments for legal output dtype / mode combinations.

    Iterates output dtypes (UINT8, INT8, INT16, INT32, UINT16) and the
    scale32, double_round and per_channel flags. Combinations that are not
    valid for inDtype are skipped. When error_name is set, combinations that
    do not exercise that ErrorIf are skipped instead. For each kept
    combination, a random scale per channel is converted into a
    multiplier/shift pair.

    Args:
        testGen: test generator; provides typeStr() for naming the tests.
        rng: random generator used to draw the per-channel scale values.
        opName: operator name, forwarded to the data generator helper.
        shapeList: input shapes. When per_channel is set, the last dimension
            of shapeList[0] is used as the channel count.
        inDtype: input tensor dtype.
        error_name: ErrorIf under test, or None for a valid test.

    Returns:
        List of (arg_str, args_dict) tuples. Each args_dict has the keys
        output_dtype, scale, double_round, per_channel, multiplier and shift.
        The list is expanded by TosaArgGen._add_data_generators.
    """
    arg_list = []

    # Enumerate the output types here
    for outDtype in [
        DType.UINT8,
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.UINT16,
    ]:
        # OutputZeroPointNotZero tests are only generated for output types
        # other than UINT8/INT8/UINT16
        if (
            outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
            and error_name == ErrorIf.OutputZeroPointNotZero
        ):
            continue
        if (
            outDtype != DType.UINT16
            and error_name == ErrorIf.U16OutputZeroPointNotValid
        ) or (
            inDtype != DType.UINT16
            and error_name == ErrorIf.U16InputZeroPointNotValid
        ):
            # ErrorIfs only valid with UINT16
            continue
        if (
            inDtype == DType.UINT8
            and outDtype not in [DType.INT8, DType.INT16]
            and error_name != ErrorIf.WrongOutputType
        ):
            # The only output dtypes for UINT8 are INT8/INT16, skip all others
            continue
        if (
            inDtype not in [DType.INT8, DType.INT16]
            and outDtype == DType.UINT8
            and error_name != ErrorIf.WrongOutputType
        ):
            # The only input dtypes for UINT8 are INT8/INT16, skip all others
            continue
        if (
            inDtype == DType.UINT16
            and outDtype != DType.INT16
            and error_name != ErrorIf.WrongOutputType
        ):
            # The only output dtype for UINT16 is INT16, skip all others
            continue
        if (
            inDtype != DType.INT16
            and outDtype == DType.UINT16
            and error_name != ErrorIf.WrongOutputType
        ):
            # The only input dtype for UINT16 is INT16, skip all others
            continue
        # For WrongOutputType tests, keep only the (in, out) pairs that the
        # ErrorIf helper reports as wrong
        if (
            error_name == ErrorIf.WrongOutputType
            and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
        ):
            continue

        for scale32 in [False, True]:
            # ScaleTrue tests only use scale32=True; ScaleNotTrue tests only
            # use scale32=False
            if error_name == ErrorIf.ScaleTrue and not scale32:
                continue
            elif error_name == ErrorIf.ScaleNotTrue and scale32:
                continue
            for double_round in [False, True]:
                # ScaleNotTrue tests are restricted to the illegal
                # combination scale32=False with double_round=True
                if error_name == ErrorIf.ScaleNotTrue and not double_round:
                    continue
                # Per_channel is only valid with rank > 0
                pc_options = (False, True) if len(shapeList[0]) > 0 else (False,)
                for per_channel in pc_options:

                    if (
                        inDtype == DType.INT48
                        and scale32
                        and error_name != ErrorIf.ScaleTrue
                    ):
                        # Illegal condition.  Must be scale32=False
                        continue
                    if (
                        double_round
                        and not scale32
                        and error_name != ErrorIf.ScaleNotTrue
                    ):
                        # Illegal condition.  ERROR_IF(!scale32 && double_round)
                        continue

                    # Number of (multiplier, shift) pairs to generate
                    if per_channel:
                        nc = shapeList[0][-1]
                    else:
                        nc = 1

                    in_type_width = gtu.dtypeWidth(inDtype)
                    out_type_width = gtu.dtypeWidth(outDtype)

                    # Calculate scale based on:
                    # scale = a *(2^output_width)/(2^input_width))

                    a = np.float32(rng.random(size=[nc]))
                    scale_arr = a * np.float32(
                        (1 << out_type_width) / (1 << in_type_width)
                    )

                    if scale32:
                        # Cap the scaling at 2^31 - 1 for scale32
                        scale_arr = np.clip(
                            scale_arr, 1.0 / (1 << 31), (1 << 31) - 1
                        )
                    else:
                        # Cap the scaling at 2^15 - 1 for scale16
                        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

                    logger.debug(
                        f"agRescale: {out_type_width} {in_type_width} -> {scale_arr}"
                    )

                    # Convert each floating-point scale into the integer
                    # multiplier/shift pair stored in the arguments
                    multiplier_arr = np.int32(np.zeros(shape=[nc]))
                    shift_arr = np.int32(np.zeros(shape=[nc]))
                    for i in range(nc):
                        (
                            multiplier_arr[i],
                            shift_arr[i],
                        ) = TosaQuantGen.computeMultiplierAndShift(
                            scale_arr[i], scale32
                        )

                    arg_list.append(
                        (
                            "out{}_sc{}_dr{}_pc{}".format(
                                testGen.typeStr(outDtype),
                                int(scale32),
                                int(double_round),
                                int(per_channel),
                            ),
                            {
                                "output_dtype": outDtype,
                                "scale": scale32,
                                "double_round": double_round,
                                "per_channel": per_channel,
                                "multiplier": multiplier_arr,
                                "shift": shift_arr,
                            },
                        )
                    )

    arg_list = TosaArgGen._add_data_generators(
        testGen,
        opName,
        shapeList,
        inDtype,
        arg_list,
        error_name,
    )
    # Return list of tuples: (arg_str, args_dict)
    return arg_list
3119
@staticmethod
def agMul(testGen, rng, opName, shapeList, dtype, error_name=None):
    """Generate MUL arguments (the result shift).

    For INT32, testGen.args.num_rand_permutations tests are created, each
    with a shift drawn by rng.randInt(0, 32). Every other dtype gets a
    single test with shift 0.

    Returns:
        List of (arg_str, args_dict) tuples with args_dict["shift"], expanded
        by TosaArgGen._add_data_generators.
    """
    arg_list = []

    # Compare by value: DType members are flatc-generated constants rather
    # than enum singletons, so `is` only matched through interpreter caching.
    if dtype == DType.INT32:
        for p in range(testGen.args.num_rand_permutations):
            shift = rng.randInt(0, 32)
            arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
    else:
        arg_list.append(("perm0_shift0", {"shift": 0}))

    arg_list = TosaArgGen._add_data_generators(
        testGen,
        opName,
        shapeList,
        dtype,
        arg_list,
        error_name,
    )
    # Return list of tuples: (arg_str, args_dict)
    return arg_list
3142
@staticmethod
def agArithmeticRightShift(testGen, rng, opName, shapeList, dtype, error_name=None):
    """Generate ARITHMETIC_RIGHT_SHIFT arguments: one test per rounding mode.

    Returns:
        List of (arg_str, args_dict) tuples. args_dict["round"] is True or
        False, and the list is expanded by TosaArgGen._add_data_generators.
    """
    arg_list = []

    # Named `rounding` so the built-in round() is not shadowed
    for rounding in (True, False):
        args_dict = {
            "round": rounding,
        }
        arg_list.append((f"round{rounding}", args_dict))

    arg_list = TosaArgGen._add_data_generators(
        testGen,
        opName,
        shapeList,
        dtype,
        arg_list,
        error_name,
    )
    # Return list of tuples: (arg_str, args_dict)
    return arg_list
3163
@staticmethod
def agFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
    """Build FFT2D argument entries, one for each transform direction.

    Each entry records the dot-product count (product of the first input
    shape), that shape, ks = 2 * shape[1] * shape[2] and the accumulator
    type (the test dtype), together with the inverse flag.

    Returns:
        List of (arg_str, args_dict) tuples, expanded by
        TosaArgGen._add_data_generators.
    """
    in_shape = shapeList[0]
    shared = {
        "dot_products": gtu.product(in_shape),
        "shape": in_shape,
        "ks": 2 * in_shape[1] * in_shape[2],  # 2*H*W
        "acc_type": dtype,
    }

    # Inverse first, then forward; each entry gets its own dict copy
    entries = [
        (f"inverse{direction}", {**shared, "inverse": direction})
        for direction in (True, False)
    ]

    return TosaArgGen._add_data_generators(
        testGen, opName, shapeList, dtype, entries, error_name
    )
3191
@staticmethod
def agRFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
    """Build the single RFFT2D argument entry (empty name suffix).

    The entry records the dot-product count (product of the first input
    shape), that shape, ks = shape[1] * shape[2] and the accumulator type
    (the test dtype).

    Returns:
        List of (arg_str, args_dict) tuples, expanded by
        TosaArgGen._add_data_generators.
    """
    in_shape = shapeList[0]
    height, width = in_shape[1], in_shape[2]
    only_entry = {
        "dot_products": gtu.product(in_shape),
        "shape": in_shape,
        "ks": height * width,  # H*W
        "acc_type": dtype,
    }

    return TosaArgGen._add_data_generators(
        testGen, opName, shapeList, dtype, [("", only_entry)], error_name
    )
3217
# Helper for reshape: lists the small divisors of a number.
@staticmethod
def getFactors(val, start=1):
    """Return the divisors of val in [start, int(sqrt(val))], ascending.

    Only the lower member of each divisor pair is included, e.g. 12 yields
    [1, 2, 3] (not 4, 6 or 12).
    """
    upper = int(np.sqrt(val))
    return [cand for cand in range(start, upper + 1) if val % cand == 0]
3228
@staticmethod
def agReshape(testGen, rng, opName, shapeList, dtype, error_name=None):
    """Generate RESHAPE arguments: random new shapes with the same element count.

    For each of testGen.args.num_rand_permutations permutations, a random
    rank in [1, gtu.MAX_TENSOR_RANK] is chosen. Ranks that need more
    dimensions than there are factors of the element count are skipped.
    Otherwise up to 100 attempts are made to build a shape of that rank
    that has not been generated already.

    Returns:
        List of (arg_str, args_dict) tuples with args_dict["new_shape"],
        expanded by TosaArgGen._add_data_generators.
    """
    in_shape = shapeList[0]
    total = gtu.product(in_shape)
    all_factors = TosaArgGen.getFactors(total)

    arg_list = []
    # Find new shapes up to the number of permutations asked for.
    # This is NOT fast, but the numbers involved are fairly small.
    for perm in range(testGen.args.num_rand_permutations):
        # Rank from 1 to MAX_TENSOR_RANK
        rank = rng.randInt(1, (gtu.MAX_TENSOR_RANK + 1))
        if rank > len(all_factors):
            continue

        # A bounded number of attempts keeps generation time reasonable
        for _attempt in range(100):
            candidate = []
            remaining = total
            pool = rng.permutation(all_factors)
            # Pick rank-1 factors; the last dimension takes what is left
            for _ in range(rank - 1):
                pick = pool[0]
                candidate.append(pick)
                remaining = remaining // pick
                pool = rng.permutation(TosaArgGen.getFactors(remaining))
            candidate.append(remaining)

            already_seen = any(
                existing["new_shape"] == candidate for _, existing in arg_list
            )
            if already_seen:
                continue

            shape_str = "x".join(str(dim) for dim in candidate)
            arg_list.append(
                (
                    "perm{}_rank{}_out{}".format(perm, rank, shape_str),
                    {"new_shape": candidate},
                )
            )
            # Found an output shape for this permutation
            break

    # Now add data generator types
    return TosaArgGen._add_data_generators(
        testGen, opName, shapeList, dtype, arg_list, error_name
    )
3290
@staticmethod
def agTranspose(testGen, rng, opName, shapeList, dtype, error_name=None):
    """Generate TRANSPOSE arguments: up to num_rand_permutations perms lists.

    The candidate permutations depend on error_name:
    - ErrorIf.IndexOutsideBounds: indices above the rank (rank+1 .. 2*rank)
      or negative (-rank .. -1).
    - ErrorIf.IndexUsedTwice: axis lists with one index duplicated.
    - otherwise: every ordering of the axes.
    The candidates are shuffled with rng, and the first
    min(len(candidates), num_rand_permutations) of them are used.

    Returns:
        List of (arg_str, args_dict) tuples with args_dict["perms"],
        expanded by TosaArgGen._add_data_generators.
    """
    rank = len(shapeList[0])

    if error_name == ErrorIf.IndexOutsideBounds:
        too_large = range(rank + 1, 2 * rank + 1)
        too_small = range(-rank, 0)
        candidates = list(itertools.permutations(too_large)) + list(
            itertools.permutations(too_small)
        )
    elif error_name == ErrorIf.IndexUsedTwice:
        # Copy one axis index over its neighbour to create a duplicate
        axes = list(range(rank))
        dup = rng.choice(range(len(axes)))
        axes[(dup + 1) % len(axes)] = axes[dup]
        candidates = list(itertools.permutations(axes))
    else:
        # Every ordering of the axes
        candidates = list(itertools.permutations(range(rank)))

    # Limit to possible permutations from shape dimension or argument setting
    count = min(len(candidates), testGen.args.num_rand_permutations)

    # Shuffle all candidates, then keep the first `count` of them
    shuffled = rng.permutation(candidates)
    arg_list = [
        ("perm{}".format(idx), {"perms": shuffled[idx].tolist()})
        for idx in range(count)
    ]

    # Now add data generator types
    return TosaArgGen._add_data_generators(
        testGen, opName, shapeList, dtype, arg_list, error_name
    )
3337
@staticmethod
def agSlice(testGen, rng, opName, shapeList, dtype, error_name=None):
    """Generate SLICE arguments: a random valid (start, size) per permutation.

    Dimensions of length 1 or less always use start 0 and size 1. When
    error_name is set, TosaErrorIfArgGen.eiSliceErrorIf may replace start
    and size with invalid values.

    Returns:
        List of (arg_str, args_dict) tuples with args_dict["start"] and
        args_dict["size"], expanded by TosaArgGen._add_data_generators.
    """
    ifm_shape = shapeList[0]
    arg_list = []

    for perm in range(testGen.args.num_rand_permutations):
        start = []
        size = []

        for dim_len in ifm_shape:
            if dim_len <= 1:
                start.append(0)
                size.append(1)
                continue
            # Start from 0 to dimension size - 1 to leave room for slice of 1
            begin = rng.randInt(0, dim_len)
            # Size from 1 up to rest of room (dimension size - start)
            extent = rng.randInt(1, dim_len + 1 - begin)
            # Should never hit an invalid slice size
            assert extent > 0 and (extent + begin) <= dim_len
            start.append(begin)
            size.append(extent)

        # If ERROR_IF test required then incorrect start, size will be returned
        start, size = TosaErrorIfArgGen.eiSliceErrorIf(
            rng, error_name, ifm_shape, start, size
        )
        arg_list.append(("perm{}".format(perm), {"start": start, "size": size}))

    # Now add data generator types
    return TosaArgGen._add_data_generators(
        testGen, opName, shapeList, dtype, arg_list, error_name
    )
3379
@staticmethod
def agTile(testGen, rng, opName, shapeList, dtype, error_name=None):
    """Generate TILE arguments: small per-dimension multiples.

    The multiples are kept small because large ones tend to produce huge
    tensors. A dimension longer than 1000 gets multiple 1. If some other
    dimension is longer than 1000, this dimension gets multiple 2.
    Otherwise the multiple is drawn with rng.randInt(1, 4).

    Returns:
        List of (arg_str, args_dict) tuples with args_dict["multiples"],
        expanded by TosaArgGen._add_data_generators.
    """
    ifm_shape = shapeList[0]

    def pick_multiple(dim_len):
        # Do not grow an already large dimension
        if dim_len > 1000:
            return 1
        # Another dimension is large: only double this one
        if max(ifm_shape) > 1000:
            return 2
        return rng.randInt(1, 4)

    arg_list = []
    for perm in range(testGen.args.num_rand_permutations):
        multiples = [pick_multiple(dim_len) for dim_len in ifm_shape]
        arg_list.append(("perm{}".format(perm), {"multiples": multiples}))

    # Now add data generator types
    return TosaArgGen._add_data_generators(
        testGen, opName, shapeList, dtype, arg_list, error_name
    )
3415
3416 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003417 def agResize(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003418 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003419 ifm_shape = shapeList[0]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003420
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003421 def get_aspect_ratio_resize_params():
3422 common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003423 aspect_ratio = rng.choice(common_aspect_ratios)
3424 invert = rng.choice((False, True))
3425 letterbox = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003426
3427 scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
3428 scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
3429 scale_y_d = scale_x_d = 1
3430 offset_x = offset_y = 0
3431
3432 if letterbox:
3433 max_border = scale_y_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003434 border_y = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003435 border_x = 0
3436 else:
3437 # Pillarboxing
3438 border_y = 0
3439 max_border = scale_x_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003440 border_x = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003441
3442 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3443 offset = (offset_y, offset_x)
3444 border = (border_y, border_x)
3445
3446 return scale, offset, border
3447
3448 def get_upscale_downscale_params():
3449 valid_params = False
3450 while not valid_params:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003451 upscale = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003452
3453 # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003454 origin_sampling = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003455
3456 if upscale:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003457 shift = rng.randInt(low=1, high=4)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003458 scale_x_d = scale_y_d = 1
3459 scale_x_n = scale_y_n = (
3460 1 << shift if origin_sampling else 2 << shift
3461 )
3462 border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
3463 offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
3464 else:
3465 scale_x_n = 1
3466 scale_y_n = 1
3467
3468 # Return list of valid scale_*_d values (max value 4) given input dim shape
3469 def get_valid_denom(ifm_dim):
3470 return [x for x in range(1, 5) if ifm_dim % x == 1]
3471
3472 # Generate list of valid downscale values and choose one randomly
3473 valid_scale_y_ds = get_valid_denom(ifm_shape[1])
3474 valid_scale_x_ds = get_valid_denom(ifm_shape[2])
3475
3476 if not valid_scale_y_ds and not valid_scale_x_ds:
3477 # Bad parameters, skip
3478 continue
3479
3480 if not valid_scale_y_ds:
3481 scale_y_d = 1
3482 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003483 scale_y_d = rng.choice(valid_scale_y_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003484
3485 if not valid_scale_x_ds:
3486 scale_x_d = 1
3487 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003488 scale_x_d = rng.choice(valid_scale_x_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003489
3490 border_x = border_y = 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003491 offset_y = rng.randInt(0, 16 * scale_y_n)
3492 offset_x = rng.randInt(0, 16 * scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003493 valid_params = True
3494
3495 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3496 offset = (offset_y, offset_x)
3497 border = (border_y, border_x)
3498 return scale, offset, border
3499
3500 def get_rand_params():
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003501 def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
3502 scale = scale_n / scale_d
3503 if scale > max_scale:
3504 factor = scale / max_scale
3505 new_scale_d = math.ceil(scale_d * factor)
3506 assert scale_n / new_scale_d <= max_scale
3507 scale_d = new_scale_d
3508 return scale_d
3509
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003510 # Scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003511 scale_y_n = rng.randInt(low=1, high=(1 << 11))
3512 scale_x_n = rng.randInt(low=1, high=(1 << 11))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003513
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003514 scale_y_d = rng.randInt(low=1, high=(16 * scale_y_n))
3515 scale_x_d = rng.randInt(low=1, high=(16 * scale_x_n))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003516
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003517 scale_y_d = fix_scale_to_max_scale(
3518 scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3519 )
3520 scale_x_d = fix_scale_to_max_scale(
3521 scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3522 )
3523
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003524 # Offsets and border within the scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003525 offset_y = rng.randInt(low=-scale_y_n, high=(16 * scale_y_n))
3526 offset_x = rng.randInt(low=-scale_x_n, high=(16 * scale_x_n))
3527 border_y = rng.randInt(low=(-16 * scale_y_n), high=scale_y_n)
3528 border_x = rng.randInt(low=(-16 * scale_x_n), high=scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003529
3530 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3531 offset = (offset_y, offset_x)
3532 border = (border_y, border_x)
3533 return scale, offset, border
3534
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003535 def get_level_8k_params():
3536 # Create 64x scale - 64/1 to 2048/32
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003537 scale_d = rng.randInt(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003538 low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
3539 )
3540 scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
3541 # Create half to fifth scaling
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003542 scale_d_alt = rng.randInt(low=2, high=6)
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003543 scale_n_alt = 1
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003544 switch = rng.choice((False, True))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003545 if switch:
3546 scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
3547 else:
3548 scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)
3549
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003550 offset_y = rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
3551 offset_x = rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003552 offset = (offset_y, offset_x)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003553 border_y = rng.choice((-16 * scale[0], 0, scale[0] - 1))
3554 border_x = rng.choice((-16 * scale[2], 0, scale[2] - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003555 border = (border_y, border_x)
3556 return scale, offset, border
3557
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003558 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003559 # Exclude illegal {mode, type} configurations. Pick legal output types
3560 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
3561 outputDTypeList = [DType.INT8]
3562 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
3563 outputDTypeList = [DType.INT16]
3564 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
3565 outputDTypeList = [DType.INT32]
3566 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
3567 outputDTypeList = [DType.INT48]
James Ward8b390432022-08-12 20:48:56 +01003568 elif dtype == DType.FP16:
3569 outputDTypeList = [DType.FP16]
James Ward24dbc422022-10-19 12:20:31 +01003570 elif dtype == DType.BF16:
3571 outputDTypeList = [DType.BF16]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003572 elif dtype == DType.FP32:
3573 outputDTypeList = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00003574 elif dtype == DType.FP8E4M3:
3575 outputDTypeList = [DType.FP8E4M3]
3576 elif dtype == DType.FP8E5M2:
3577 outputDTypeList = [DType.FP8E5M2]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003578 elif error_name == ErrorIf.WrongInputType:
3579 # If an incorrect input type is used then we set a 'correct'
3580 # output type to avoid other errors
3581 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
3582 else:
3583 continue
3584
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003585 arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"
3586
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003587 for outputDType in outputDTypeList:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003588 perm = 0
3589 while perm < testGen.args.num_rand_permutations:
3590 # Random choice of type of params we are testing
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003591 if not testGen.args.level8k:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003592 _rnd_param_fn = rng.choice(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003593 (
3594 get_rand_params,
3595 get_upscale_downscale_params,
3596 get_aspect_ratio_resize_params,
3597 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003598 )
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003599 scale, offset, border = _rnd_param_fn()
3600 else:
3601 scale, offset, border = get_level_8k_params()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003602
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003603 # Expand params for bounds-checking
3604 (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
3605 (offset_y, offset_x) = offset
3606 (border_y, border_x) = border
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003607
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003608 # Make sure output dimensions OH and OW are integers
3609 partial_output_y = (
3610 (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
3611 )
3612 partial_output_x = (
3613 (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
3614 )
3615 if error_name == ErrorIf.ResizeOutputShapeNonInteger:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003616 # Look for non-integer test
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003617 if (
3618 partial_output_y % scale_y_d == 0
3619 and partial_output_x % scale_x_d == 0
3620 ):
3621 # Skip this test as it doesn't produce NonInteger output
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003622 if perm > 0:
3623 perm += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003624 continue
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003625 else:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003626 # Alter the scaling factors to make the output integer
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003627 while partial_output_y % scale_y_d != 0:
3628 scale_y_d -= 1
3629 while partial_output_x % scale_x_d != 0:
3630 scale_x_d -= 1
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003631 # Make sure we are still within max scaling
3632 if (
3633 scale_y_n / scale_y_d
3634 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
3635 scale_x_n / scale_x_d
3636 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
3637 # Skip the test as it is using too large a scaling factor
3638 if perm > 0:
3639 perm += 1
3640 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003641
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003642 output_y = partial_output_y // scale_y_d + 1
3643 output_x = partial_output_x // scale_x_d + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003644
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003645 if (
3646 output_y >= testGen.args.max_resize_output_dim
3647 or output_x >= testGen.args.max_resize_output_dim
3648 ) and error_name is None:
3649 # Skip positive test if output dim will be too high
3650 # Avoid high test latency and OOM issues
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003651 if not testGen.args.level8k or perm > 0:
3652 perm += 1
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003653 continue
3654
3655 if (
3656 output_y <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003657 or output_y >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003658 or output_x <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003659 or output_x >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003660 ):
3661 # Output dimensions out of scope
3662 if error_name is not None and perm > 0:
3663 # As long as we have one ERROR_IF test, don't worry
3664 # about creating all the other permutations
3665 perm += 1
3666 continue
3667
3668 if error_name == ErrorIf.ResizeOutputShapeMismatch and (
3669 (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003670 output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003671 and output_y - scale_y_d < 1
3672 )
3673 or (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003674 output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003675 and output_x - scale_x_d < 1
3676 )
3677 ):
3678 # Can't create a negative test with these params as it
3679 # will create invalid output size
3680 if perm > 0:
3681 perm += 1
3682 continue
3683
3684 scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
3685 offset = [offset_y, offset_x]
3686 border = [border_y, border_x]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003687
3688 # Common for all data types
3689 if error_name is not None:
3690 (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003691 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003692 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003693 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003694 outputDTypeNew,
3695 ) = TosaErrorIfArgGen.eiResizeErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003696 rng,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003697 error_name,
3698 mode,
3699 dtype,
3700 shapeList,
3701 outputDType,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003702 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003703 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003704 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003705 )
3706 else:
3707 outputDTypeNew = outputDType
3708
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003709 arg_to_append = (
3710 arg_str.format(
3711 "N" if mode == ResizeMode.NEAREST else "B",
3712 testGen.typeStr(outputDTypeNew),
3713 scale[0],
3714 scale[1],
3715 scale[2],
3716 scale[3],
3717 offset[0],
3718 offset[1],
3719 border[0],
3720 border[1],
3721 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003722 {
3723 "mode": mode,
3724 "scale": scale,
3725 "offset": offset,
3726 "border": border,
3727 "output_dtype": outputDTypeNew,
3728 },
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003729 )
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003730 if arg_to_append in arg_list:
3731 # Skip already generated test params
3732 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003733
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003734 # Valid permutation
3735 perm += 1
3736 arg_list.append(arg_to_append)
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003737
3738 # Now add data generator types
3739 arg_list = TosaArgGen._add_data_generators(
3740 testGen,
3741 opName,
evacha019c96eef2024-02-07 11:21:55 +00003742 shapeList,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003743 dtype,
3744 arg_list,
3745 error_name,
3746 )
3747 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003748 return arg_list
3749
3750 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003751 def agTable(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003752 arg_list = []
3753
3754 if dtype == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003755 table = np.int32(rng.integers(low=-128, high=128, size=[256])).tolist()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003756 else: # INT16
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003757 table = np.int32(rng.integers(low=-32768, high=32768, size=[513])).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003758 # Make sure all slopes are within REQUIRE min/max 16-bit int
3759 for idx in range(len(table) - 1):
3760 slope = table[idx + 1] - table[idx]
3761 # Alter the next table entry to force the slope to be ok
3762 if slope > 32767:
3763 table[idx + 1] -= slope - 32767
3764 if slope < -32768:
3765 table[idx + 1] -= slope + 32768
3766 slope = table[idx + 1] - table[idx]
3767 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003768 arg_list.append(
3769 (
3770 "",
Jeremy Johnson587cc842024-02-08 11:45:44 +00003771 {"table": table},
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003772 )
3773 )
Jeremy Johnson587cc842024-02-08 11:45:44 +00003774 # Now add data generator types
3775 arg_list = TosaArgGen._add_data_generators(
3776 testGen,
3777 opName,
evacha019c96eef2024-02-07 11:21:55 +00003778 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003779 dtype,
3780 arg_list,
3781 error_name,
3782 )
3783 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003784 return arg_list
3785
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003786 def agCondIf(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003787 # CondIf generates the condition values here.
3788 # Convert to tensors in the build function, along with the
3789 # then and else blocks
3790 arg_list = []
3791
3792 for c in [False, True]:
Jeremy Johnson587cc842024-02-08 11:45:44 +00003793 arg_list.append(("cond{}".format(int(c)), {"condition": c}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003794
Jeremy Johnson587cc842024-02-08 11:45:44 +00003795 # Now add data generator types
3796 arg_list = TosaArgGen._add_data_generators(
3797 testGen,
3798 opName,
evacha019c96eef2024-02-07 11:21:55 +00003799 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003800 dtype,
3801 arg_list,
3802 error_name,
3803 )
3804 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003805 return arg_list
3806
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003807 def agWhileLoop(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003808 # While loop: 0 iterations, 1, more than 1
3809 arg_list = []
3810
Jeremy Johnson587cc842024-02-08 11:45:44 +00003811 for iterations in [0, 1, 4]:
3812 arg_list.append(("iter{}".format(iterations), {"iterations": iterations}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003813
Jeremy Johnson587cc842024-02-08 11:45:44 +00003814 # Now add data generator types
3815 arg_list = TosaArgGen._add_data_generators(
3816 testGen,
3817 opName,
evacha019c96eef2024-02-07 11:21:55 +00003818 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003819 dtype,
3820 arg_list,
3821 error_name,
3822 )
3823 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003824 return arg_list