blob: 79d4e788360d7ac7efdf795906b9457c02ce19ca [file] [log] [blame]
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
Jeremy Johnsonaf090182024-02-13 18:25:39 +00004import logging
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01005import math
6
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
16# DTypeNames, DType, Op and ResizeMode are convenience variables to the
17# flatc-generated types that should be enums, but aren't
18
Jeremy Johnsonaf090182024-02-13 18:25:39 +000019logging.basicConfig()
20logger = logging.getLogger("tosa_verif_build_tests")
21
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010022
23class TosaQuantGen:
24 """QuantizedInfo random generator helper functions.
25
26 Specify with 'qgen': in the operator defintion.
27 """
28
29 def __init__(self):
30 pass
31
32 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010033 def getZeroPoint(rng, zeropoint, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010034
35 if dtype == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010036 if zeropoint is not None:
37 return min(127, max(-128, zeropoint))
38 return rng.randInt(-128, 128)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010039 elif dtype == DType.UINT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010040 if zeropoint is not None:
41 return min(255, max(0, zeropoint))
42 return rng.randInt(0, 256)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010043 elif error_name in [
44 ErrorIf.InputZeroPointNotZero,
45 ErrorIf.WeightZeroPointNotZero,
46 ErrorIf.OutputZeroPointNotZero,
47 ]:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010048 zero_point = rng.randInt(-128, 128)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010049 if zero_point == 0:
50 zero_point = 1
51 return zero_point
52 return 0
53
54 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010055 def qgUnary(rng, zeropoint, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010056 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000057 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010058 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
59 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000060 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010061 elif error_name == ErrorIf.OutputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000062 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010063 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
64 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000065 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010066 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000067 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010068 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
69 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000070 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010071 return qinfo
72
73 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010074 def qgConv(rng, zeropoint, op, dtype_or_dtypeList, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010075 if isinstance(dtype_or_dtypeList, list):
76 # a list of [input, weights, accumulator] dtypes
77 dtypeList = dtype_or_dtypeList
78 else:
79 # an int, [input, weights, accumulator] dtypes are the same
80 dtypeList = [dtype_or_dtypeList] * 3
81
82 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000083 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010084 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0], error_name),
85 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1]),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000086 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010087 elif error_name == ErrorIf.WeightZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000088 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010089 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0]),
90 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1], error_name),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000091 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010092 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000093 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010094 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0]),
95 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1]),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000096 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010097 return qinfo
98
99 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100100 def qgMatmul(rng, zeropoint, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100101 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000102 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100103 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
104 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000105 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100106 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000107 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100108 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
109 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000110 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100111 return qinfo
112
113 @staticmethod
114 def computeMultiplierAndShift(scaleFp, scale32):
115 # Derived from computeMultiplierAndShiftTosaScale32
116 # Provide a floating-point scaling factor and the scale32 parameter
117 # to compute the multiplier and shift
118
119 if scale32:
120 scaleBits = 31
121 else:
122 scaleBits = 15
123
124 m, shift = math.frexp(scaleFp)
125
126 if scaleFp < 0.0:
127 m = -m
128
129 multiplier = round(m * (1 << scaleBits))
130 assert multiplier <= (1 << scaleBits)
131
132 if multiplier == (1 << scaleBits):
133 multiplier = multiplier // 2
134 shift = shift + 1
135
136 shift = (-shift) + scaleBits
Jeremy Johnsonaf090182024-02-13 18:25:39 +0000137 logger.debug(
138 f"computeMultiplierAndShift: scalefp={scaleFp} scaleBits={scaleBits} m={m} mult={multiplier} shift={shift}"
139 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100140
141 # Adjust multiplier such that shift is in allowed value range.
142 if shift == 0:
143 multiplier = multiplier // 4
144 shift = shift + 2
145 elif shift == 1:
146 multiplier = multiplier // 2
147 shift = shift + 1
148 elif shift == 63:
149 multiplier = multiplier * 2
150 shift = shift - 1
151
152 assert multiplier <= (1 << scaleBits)
153 assert shift >= 2 and shift <= 62
154
155 return multiplier, shift
156
157
158class TosaTensorGen:
159 """Tensor generators create a shape list for the placeholder and const tensor
160 data operands for the operator.
161
162 The actual random data is generated separately for each test.
163 """
164
165 def __init__(self):
166 pass
167
168 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100169 def tgBasic(testGen, rng, op, rank, error_name=None):
170 pl, const = op["operands"]
171 shape = testGen.makeShape(rng, rank)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100172
173 # Constrict the overall size of the shape when creating ERROR_IF tests
174 if error_name:
175 shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
176
177 shape_list = []
178 for i in range(pl + const):
179 shape_list.append(shape.copy())
180
Luke Huttona4e48ca2023-02-22 11:53:48 +0000181 # Generates an input rank mismatch for operators with more than one input
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100182 if error_name == ErrorIf.RankMismatch:
183 if rank == 1 and i != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100184 shape = testGen.makeShape(rng, rank + rng.choice([1, 2, 3]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100185 elif i != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100186 shape = testGen.makeShape(rng, rank + rng.choice([-1, 1]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100187
188 return shape_list
189
190 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100191 def tgNHWC(testGen, rng, op, rank, error_name=None):
192 pl, const = op["operands"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100193
194 if error_name != ErrorIf.WrongRank:
195 assert rank == 4
196
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100197 shape = testGen.makeShape(rng, rank)
James Ward30124a82023-02-02 14:56:33 +0000198 shape = testGen.constrictBatchSize(shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100199
200 # Constrict the overall size of the shape when creating ERROR_IF tests
201 if error_name and error_name != ErrorIf.MaxDimExceeded:
202 shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
203
204 shape_list = []
205 for i in range(pl + const):
206 shape_list.append(shape.copy())
207
208 return shape_list
209
210 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100211 def tgGather(testGen, rng, opName, rank, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100212 pl, const = opName["operands"]
213
214 assert pl == 2
215 assert const == 0
216 if error_name != ErrorIf.WrongRank:
217 assert rank == 3
218
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100219 values_shape = testGen.makeShape(rng, rank)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000220 values_shape = testGen.constrictBatchSize(values_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100221
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000222 N = values_shape[0]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100223 W = testGen.makeDimension(rng)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000224 indices_shape = [N, W]
225
226 shape_list = [values_shape, indices_shape]
227 return shape_list
228
229 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100230 def tgScatter(testGen, rng, opName, rank, error_name=None):
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000231 pl, const = opName["operands"]
232
233 assert pl == 3
234 assert const == 0
235 if error_name != ErrorIf.WrongRank:
236 assert rank == 3
237
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100238 values_in_shape = testGen.makeShape(rng, rank)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000239 values_in_shape = testGen.constrictBatchSize(values_in_shape)
240
241 N = values_in_shape[0]
242 K = values_in_shape[1]
243 C = values_in_shape[2]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100244
Jeremy Johnson194fe312023-12-07 14:17:57 +0000245 # Make sure W is not greater than K, as we can only write each output index
246 # once (having a W greater than K means that you have to repeat a K index)
247 W_min = min(testGen.args.tensor_shape_range[0], K)
248 W_max = min(testGen.args.tensor_shape_range[1], K)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100249 W = rng.randInt(W_min, W_max) if W_min < W_max else W_min
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100250
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000251 input_shape = [N, W, C]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100252
253 shape_list = []
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000254 shape_list.append(values_in_shape)
255 shape_list.append([N, W]) # indices
256 shape_list.append(input_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100257
258 return shape_list
259
260 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100261 def _get_broadcast_shapes(testGen, rng, num_shapes, rank, error_name=None):
262 shape = testGen.makeShape(rng, rank)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100263 shape_list = []
264
265 # Choose one of the inputs to broadcast
266 # Note: Simplifies OutputShaper code if we don't change first shape for errors
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100267 bcast_idx = rng.randInt(0 if error_name is None else 1, num_shapes)
268 fuzz_idx = rng.randInt(0, rank)
Jerry Ge135c9552023-05-23 20:59:32 +0000269
Jeremy Johnson0a042992024-02-28 13:20:05 +0000270 for i in range(num_shapes):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100271 shape_bcast = shape.copy()
272
Jerry Ge135c9552023-05-23 20:59:32 +0000273 # To test broadcasting, the chosen fuzz index dimension should not be 1
274 if shape_bcast[fuzz_idx] == 1:
275 shape_bcast[fuzz_idx] += 1
276
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100277 # If the chosen input, pick a random index to broadcast
278 if i == bcast_idx:
Jerry Ge135c9552023-05-23 20:59:32 +0000279 if error_name == ErrorIf.RankMismatch:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100280 # Add one rank to the shape (or more for rank of 1)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100281 extra_ranks = rng.choice([1, 2, 3]) if rank == 1 else 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100282 shape_bcast = np.concatenate(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100283 (shape_bcast, testGen.makeShape(rng, extra_ranks))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100284 )
285 if rank != 1:
286 # Either keep the extra rank, or remove it
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100287 new_len = rng.choice([-2, len(shape_bcast)])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100288 shape_bcast = shape_bcast[:new_len]
Jerry Ge135c9552023-05-23 20:59:32 +0000289 elif error_name == ErrorIf.BroadcastShapesMismatch:
290 shape_bcast[fuzz_idx] += 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100291 else:
292 shape_bcast[fuzz_idx] = 1
293
294 shape_list.append(shape_bcast)
295
296 return shape_list
297
298 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100299 def tgBroadcastFuzz(testGen, rng, op, rank, error_name=None):
Jeremy Johnson0a042992024-02-28 13:20:05 +0000300 pl, const = op["operands"]
301 num_shapes = pl + const
302 return TosaTensorGen._get_broadcast_shapes(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100303 testGen, rng, num_shapes, rank, error_name
Jeremy Johnson0a042992024-02-28 13:20:05 +0000304 )
305
306 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100307 def tgMul(testGen, rng, op, rank, error_name=None):
Jeremy Johnson0a042992024-02-28 13:20:05 +0000308 # Get broadcast shapes for the first 2 inputs as the 3rd is shift
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100309 shape_list = TosaTensorGen._get_broadcast_shapes(
310 testGen, rng, 2, rank, error_name
311 )
Jeremy Johnson0a042992024-02-28 13:20:05 +0000312 # Add a single dimension tensor for shift
313 shape_list.append([1])
314 return shape_list
315
316 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100317 def tgConv2D(testGen, rng, op, rank, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100318 pl, const = op["operands"]
319
320 if error_name != ErrorIf.WrongRank:
321 assert rank == 4
322
323 # IFM dimensions are NHWC
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100324 ifm_shape = testGen.makeShape(rng, rank)
James Ward30124a82023-02-02 14:56:33 +0000325 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100326
327 # Constrict the overall size of the shape when creating ERROR_IF tests
328 if error_name:
329 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
330 ifm_shape, max_dim=24, max_items=10000
331 )
332
333 # Get the filter height/width from the operator parameters
334 filter_hw = op["filter"]
335
336 # Generate a random OFM depth
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100337 ofm_depth = testGen.makeDimension(rng)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100338
339 # The filter dimensions are OHWI
340 filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
341
Jeremy Johnson5e36bde2024-03-14 16:56:10 +0000342 # The bias is OC or 1 if broadcastable
343 try:
344 if op["broadcastable_bias"]:
345 if rng.choice([True, False]):
346 ofm_depth = 1
347 except KeyError:
348 pass
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100349 bias_shape = np.asarray([ofm_depth])
350
351 return [ifm_shape, filter_shape, bias_shape]
352
353 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100354 def tgConv3D(testGen, rng, op, rank, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100355 pl, const = op["operands"]
356
357 if error_name != ErrorIf.WrongRank:
358 assert rank == 5
359
360 # IFM dimensions are NDHWC
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100361 ifm_shape = testGen.makeShape(rng, rank)
James Ward30124a82023-02-02 14:56:33 +0000362 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100363
364 # Constrict the overall size of the shape when creating ERROR_IF tests
365 if error_name:
366 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
367 ifm_shape, max_dim=24, max_items=10000
368 )
369
370 # Get the filter depth/height/width from the operator parameters
371 filter_dhw = op["filter"]
372
373 # Generate a random OFM channel
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100374 ofm_channel = testGen.makeDimension(rng)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100375
376 # The filter dimensions are ODHWI
377 filter_shape = np.asarray(
378 [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
379 )
380
381 # The bias is OC
382 bias_shape = np.asarray([ofm_channel])
383
384 return [ifm_shape, filter_shape, bias_shape]
385
386 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100387 def tgTransposeConv2D(testGen, rng, op, rank, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100388 pl, const = op["operands"]
389
390 if error_name != ErrorIf.WrongRank:
391 assert rank == 4
392
393 # IFM dimensions are NHWC
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100394 ifm_shape = testGen.makeShape(rng, rank)
James Ward30124a82023-02-02 14:56:33 +0000395 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100396
397 # Constrict the overall size of the shape when creating ERROR_IF tests
398 if error_name:
399 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
400 ifm_shape, max_dim=24, max_items=10000
401 )
402
403 # Get the filter height/width from the operator parameters
404 filter_hw = op["filter"]
405
406 # Generate a random OFM depth
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100407 ofm_depth = testGen.makeDimension(rng)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100408
409 # The filter dimensions are OHWI
410 filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
411
412 # The bias is OC
413 bias_shape = np.asarray([ofm_depth])
414
415 return [ifm_shape, filter_shape, bias_shape]
416
417 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100418 def tgDepthwiseConv2D(testGen, rng, op, rank, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100419 pl, const = op["operands"]
420
421 if error_name != ErrorIf.WrongRank:
422 assert rank == 4
423 assert pl == 1 and const == 2
424
425 # IFM dimensions are NHWC
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100426 ifm_shape = testGen.makeShape(rng, rank)
James Ward30124a82023-02-02 14:56:33 +0000427 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100428
429 # Constrict the overall size of the shape when creating ERROR_IF tests
430 if error_name:
431 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
432 ifm_shape, max_dim=24, max_items=10000
433 )
434
435 # Get the filter height/width from the operator parameters
436 # Filter is KH, HW, C, M
437 filter_hw = op["filter"]
438
439 # Generate a random OFM depth, but don't let it get too big because
440 # the output depth is M * C
441 filter_m = (
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100442 testGen.makeDimension(rng) % (testGen.args.tensor_shape_range[1] // 4)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100443 ) + 1
444
445 # The filter dimensions are HWCM
446 filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])
447
448 # The bias is M * C
449 bias_shape = np.asarray([ifm_shape[3] * filter_m])
450
451 return [ifm_shape, filter_shape, bias_shape]
452
453 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100454 def tgFFT2d(testGen, rng, op, rank, error_name=None):
Luke Hutton57287132023-02-06 14:54:18 +0000455 pl, const = op["operands"]
456
457 if error_name != ErrorIf.WrongRank:
458 assert rank == 3
459 assert pl == 2 and const == 0
460
461 # IFM dimensions are NHW
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100462 ifm_shape = testGen.makeShape(rng, rank)
Luke Hutton57287132023-02-06 14:54:18 +0000463
464 # Select nearest lower power of two from input height and width
465 ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
466 ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))
467
468 # Constrict the overall size of the shape when creating ERROR_IF tests
469 if error_name:
470 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)
471
472 # Generate an invalid kernel that is not a power of two
473 if error_name == ErrorIf.KernelNotPowerOfTwo:
474 inc_h = 2 if ifm_shape[1] == 1 else 1
475 inc_w = 2 if ifm_shape[2] == 1 else 1
476 inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100477 selected_inc = rng.choice(inc_choices)
Luke Hutton57287132023-02-06 14:54:18 +0000478 ifm_shape[1] += selected_inc[0]
479 ifm_shape[2] += selected_inc[1]
480
481 ifm_shape = testGen.constrictBatchSize(ifm_shape)
482
483 ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
484 if error_name == ErrorIf.FFTInputShapeMismatch:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100485 modify_shape = rng.choice([0, 1])
Luke Hutton57287132023-02-06 14:54:18 +0000486 # Only modify kernel (H, W)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100487 modify_dim = rng.choice([1, 2])
Luke Hutton57287132023-02-06 14:54:18 +0000488 ifm_shapes[modify_shape][modify_dim] *= 2
489
490 return [ifm_shapes[0], ifm_shapes[1]]
491
492 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100493 def tgRFFT2d(testGen, rng, op, rank, error_name=None):
Luke Hutton261b7b62023-01-10 14:50:31 +0000494 pl, const = op["operands"]
495
496 if error_name != ErrorIf.WrongRank:
497 assert rank == 3
498 assert pl == 1 and const == 0
499
500 # IFM dimensions are NHW
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100501 ifm_shape = testGen.makeShape(rng, rank)
Luke Hutton261b7b62023-01-10 14:50:31 +0000502
503 # Select nearest lower power of two from input height and width
504 ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
505 ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))
506
507 # Constrict the overall size of the shape when creating ERROR_IF tests
508 if error_name:
509 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)
510
511 # Generate an invalid kernel that is not a power of two
512 if error_name == ErrorIf.KernelNotPowerOfTwo:
513 # We must increment by 2 if current size is 1
514 inc_h = 2 if ifm_shape[1] == 1 else 1
515 inc_w = 2 if ifm_shape[2] == 1 else 1
516 inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100517 selected_inc = rng.choice(inc_choices)
Luke Hutton261b7b62023-01-10 14:50:31 +0000518 ifm_shape[1] += selected_inc[0]
519 ifm_shape[2] += selected_inc[1]
520
James Ward30124a82023-02-02 14:56:33 +0000521 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Luke Hutton261b7b62023-01-10 14:50:31 +0000522
523 return [ifm_shape]
524
525 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100526 def tgFullyConnected(testGen, rng, op, rank, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100527 pl, const = op["operands"]
528
529 if error_name != ErrorIf.WrongRank:
530 assert rank == 2
531
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100532 input_shape = testGen.makeShape(rng, rank)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100533
534 # Constrict the overall size of the shape when creating ERROR_IF tests
535 if error_name:
536 input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)
537
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100538 filter_oc = rng.integers(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100539 low=testGen.args.tensor_shape_range[0],
540 high=testGen.args.tensor_shape_range[1],
541 size=1,
542 )[0]
543 filter_shape = np.asarray([filter_oc, input_shape[1]])
544
545 bias_shape = np.asarray([filter_oc])
546
547 return [input_shape, filter_shape, bias_shape]
548
549 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100550 def tgMatmul(testGen, rng, op, rank, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100551 pl, const = op["operands"]
552
553 if error_name != ErrorIf.WrongRank:
554 assert rank == 3
555 assert pl == 2 and const == 0
556
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100557 a_shape = testGen.makeShape(rng, rank)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100558
559 # Constrict the overall size of the shape when creating ERROR_IF tests
560 if error_name:
561 a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)
562
563 # Get a random number for b_oc even if target shape is defined
564 b_oc = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100565 rng.integers(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100566 low=testGen.args.tensor_shape_range[0],
567 high=testGen.args.tensor_shape_range[1],
568 size=1,
569 )
570 )[0]
571 # If N or H is large let b_oc be 1 to reduce output tensor size
572 if max(a_shape) > 1000:
573 b_oc = 1
574
575 b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
576 return [a_shape, b_shape]
577
578 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100579 def tgConcat(testGen, rng, op, rank, error_name=None):
580 pl, const = op["operands"]
581 shape = testGen.makeShape(rng, rank)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100582
583 # Create extra tensors to concat.
584 # Take into account value of pl when getting maximum number of concats
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100585 num_tensors = rng.randInt(0, 4)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100586 shape_list = []
587 for i in range(pl + const + num_tensors):
588 if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100589 remove = rng.choice([True, False])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100590 wrongShape = shape.copy()
591
592 if remove and len(shape) > 1:
593 wrongShape = wrongShape[1:]
594 else:
595 wrongShape = list(wrongShape)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100596 wrongShape.append(rng.integers(1, 10))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100597
598 shape_list.append(wrongShape)
599 else:
600 shape_list.append(shape.copy())
601
602 return shape_list
603
604 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100605 def tgConcatConstInput(rng, shapeList, axis, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100606 if error_name in [
607 ErrorIf.AxisSmallerZero,
608 ErrorIf.AxisLargerRank,
609 ErrorIf.ConcatInputRankMismatch,
610 ]:
611 return shapeList
612
613 # Split concat shape along axis to allow for multiple const inputs
614 # without making too many large tensors
615 if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
616 # If axis can't be split we still need to invalidate other dimensions
617 if error_name == ErrorIf.ConcatInputDimMismatch:
618 for shape in shapeList[1:]:
619 # Negative test shapeLists are created individually for each test,
620 # so no need to copy the shape before altering it.
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100621 shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100622 return shapeList
623
624 # Create copy of shape we are going to split (so we don't alter shapeList)
625 shape = shapeList[0].copy()
626 # Add original shape as first input
627 new_shapeList = [shape.copy()]
628 length_on_axis = shape[axis]
629 remaining_length = length_on_axis
630 for i in range(len(shapeList) - 2):
631 # Calculate split on axis and remaining value
632 split_shape_val = int(shape[axis] / 2)
633 remaining_length = remaining_length - split_shape_val
634
635 # Append new shape, and set remaining shape
636 shape[axis] = split_shape_val
637 new_shapeList.append(shape.copy())
638
639 # invalidate dimensions
640 if error_name == ErrorIf.ConcatInputDimMismatch:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100641 shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100642 else:
643 shape[axis] = remaining_length
644
645 if i == len(shapeList) - 3:
646 new_shapeList.append(shape.copy())
647
648 return new_shapeList
649
650
651class TosaTensorValuesGen:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100652 """Tensor Value generators create the random data for each tensor in each test."""
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100653
654 def __init__(self):
655 pass
656
Jeremy Johnson1271c442023-09-05 11:39:26 +0100657 class TVGInfo:
658 """Enhanced tensor values information including data gen dict."""
659
660 def __init__(self, tensorList, dataGenDict):
661 self.tensorList = tensorList
662 self.dataGenDict = dataGenDict
663
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100664 # Default high value for random numbers
665 TVG_FLOAT_HIGH_VALUE = {
666 DType.FP32: (1 << 128) - (1 << (127 - 23)),
667 DType.FP16: (1 << 16) - (1 << (15 - 10)),
668 DType.BF16: (1 << 128) - (1 << (127 - 7)),
Won Jeon2c34b462024-02-06 18:37:00 +0000669 DType.FP8E4M3: 448,
670 DType.FP8E5M2: 57344,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100671 }
672
Jeremy Johnson30476252023-11-20 16:15:30 +0000673 # Default lowest normal values for random numbers
674 TVG_FLOAT_LOW_VALUE = {
675 DType.FP32: np.exp2(-126),
676 DType.FP16: np.exp2(-14),
677 DType.BF16: np.exp2(-126),
Won Jeon2c34b462024-02-06 18:37:00 +0000678 DType.FP8E4M3: np.exp2(-9),
679 DType.FP8E5M2: np.exp2(-16),
Jeremy Johnson30476252023-11-20 16:15:30 +0000680 }
681
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100682 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100683 def _get_data_range(rng, dtype, highValueLookup, lowValueLookup=None):
Jeremy Johnson30476252023-11-20 16:15:30 +0000684 # Return a tuple of (low,high) data range values for the given data
685 # type using a combination of per operator table limits, data limits
686 # and user supplied ranges for FP numbers
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000687 if dtype in highValueLookup:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100688 type_range = rng.dTypeRange(dtype, high_inclusive=True)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000689 high_val = highValueLookup[dtype]
Jeremy Johnson30476252023-11-20 16:15:30 +0000690 if lowValueLookup is not None and dtype in lowValueLookup:
691 low_val = lowValueLookup[dtype]
692 else:
693 low_val = -high_val
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000694 # Set the values to something that won't produce infinity whilst
Jeremy Johnson30476252023-11-20 16:15:30 +0000695 # respecting the default ranges if more/less than the low/high
696 # values
697 data_range = (
698 max(low_val, type_range[0]),
699 min(high_val, type_range[1]),
700 )
701 if data_range[0] > data_range[1]:
702 # Invalid data range from low to high created due to user
703 # constraints revert to using internal ranges as they are
704 # known to work
Jeremy Johnsonaf090182024-02-13 18:25:39 +0000705 logger.info(
706 f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
707 )
Jeremy Johnson30476252023-11-20 16:15:30 +0000708 data_range = (low_val, high_val)
709 return data_range
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000710 return None
711
    @staticmethod
    def tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Default tensor value generator supporting lazy data generation.

        Either creates tensor data immediately (for ERROR_IF tests, types/ops
        without compliance support, or when lazy generation is disabled) or
        builds a data-generator meta-data dictionary describing how to create
        the data later.

        Returns a TVGInfo of the serialized tensors plus the meta-data (None
        when data was created immediately).
        """
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
        if "p_count" in argsDict:
            # Override for operators like CONCAT
            pCount = argsDict["p_count"]
            cCount = argsDict["c_count"]
        assert pCount + cCount == len(
            shapeList
        ), "Placeholders & Constant tensors must match shapes list"

        tens_ser_list = []

        if (
            error_name is not None
            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
            or "data_gen" not in testGen.TOSA_OP_LIST[opName]
        ):
            # Fall back to internal data gen when dealing with unsupported types or ops
            data_range = argsDict["data_range"] if "data_range" in argsDict else None
            for idx, info in enumerate(zip(shapeList, dtypeList)):
                roundMode = False
                shape, dtype = info
                # Per-tensor range/round settings take precedence over the
                # single "data_range" setting
                if "data_range_list" in argsDict:
                    data_range = argsDict["data_range_list"][idx]["range"]
                    roundMode = (
                        "round" in argsDict["data_range_list"][idx]
                        and argsDict["data_range_list"][idx]["round"] is True
                    )
                if data_range is not None and dtype not in (
                    DType.FP16,
                    DType.FP32,
                    DType.BF16,
                    DType.FP8E4M3,
                    DType.FP8E5M2,
                ):
                    # Change from inclusive to exclusive range
                    # (integer random generation uses an exclusive high bound)
                    data_range = (data_range[0], data_range[1] + 1)

                # Ignore lazy data gen option and create data array using any range limits
                if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                    # Pre-generated values (e.g. shape/multiplier/shift operands)
                    if dtype == DType.SHAPE:
                        arr = np.int64(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT8:
                        arr = np.int8(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT16:
                        arr = np.int16(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT32:
                        arr = np.int32(argsDict["fixed_data"][idx])
                    else:
                        assert False, "Unsupported fixed_data type"
                else:
                    arr = rng.randTensor(shape, dtype, data_range)
                if roundMode:
                    arr = np.round(arr)
                # First pCount tensors are variable inputs, the rest constants
                if idx < pCount:
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
        dg_type = argsDict["dg_type"]
        tens_data = {
            "version": "0.1",
            "tensors": {},
        }
        dg_tens_meta = tens_data["tensors"]
        for idx, shape in enumerate(shapeList):

            tens_meta = {}
            # Fixed data overrides the requested data generator type
            if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                tens_meta["generator"] = gtu.DataGenType(
                    gtu.DataGenType.FIXED_DATA
                ).name
            else:
                tens_meta["generator"] = gtu.DataGenType(dg_type).name

            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()

            if idx < pCount:
                tens_meta["input_type"] = "VARIABLE"
            else:
                tens_meta["input_type"] = "CONSTANT"

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                if (
                    tens_meta["generator"]
                    == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name
                ):
                    info["data"] = [int(i) for i in argsDict["fixed_data"][idx]]
                    tens_meta["fixed_data_info"] = info
                else:
                    info["rng_seed"] = rng.seed

                    # Per-tensor range list wins over the operator-wide range
                    data_range = None
                    if "data_range_list" in argsDict:
                        data_range = argsDict["data_range_list"][idx]["range"]
                        if "round" in argsDict["data_range_list"][idx]:
                            info["round"] = argsDict["data_range_list"][idx]["round"]
                    elif "data_range" in argsDict:
                        data_range = argsDict["data_range"]

                    if data_range is None:
                        data_range = rng.dTypeRange(dtypeList[idx], high_inclusive=True)
                    info["range"] = [str(v) for v in data_range]
                    tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = int(argsDict["ks"])
                if "acc_type" in argsDict:
                    # Convert type number into JSON name
                    info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
                        "json"
                    ]
                if "kernel" in argsDict:
                    info["kernel"] = [int(k) for k in argsDict["kernel"]]
                if "axis" in argsDict:
                    info["axis"] = int(argsDict["axis"])
                tens_meta["dot_product_info"] = info
            elif dg_type == gtu.DataGenType.FULL_RANGE:
                info = {}
                info["start_val"] = int(
                    rng.randInt(0, gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["fullset"])
                )
                tens_meta["full_range_info"] = info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"

            # Using the finished generate config meta data - generate the data if
            # needed and assign a tensor name from the serializer

            # Need to generate data when not lazy or for the bias tensor as we need
            # to work out if the bias data is non-zero for compliance
            if not testGen.args.lazy_data_gen or (
                idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
            ):
                # Give this tensor a temporary name until we get one from the serializer
                temp_name = f"placeholder_{idx}"
                dg_tens_meta[temp_name] = tens_meta
                # Create data now using the temporary name to access meta details
                data = testGen.dgl.get_tensor_data(temp_name, tens_data)
                if tens_meta["data_type"] == "SHAPE":
                    # Tensor type SHAPE and Numpy file type must be the same
                    data = np.int64(data)
                # Remove the item as we will give it the correct name later
                del dg_tens_meta[temp_name]

            if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
                # The KS value used by compliance verification is altered when the
                # bias data is non-zero
                # NOTE(review): assumes the bias tensor (input_pos 2) is 1-D so
                # max(abs(data)) reduces to a scalar - confirm for all dot
                # product operators
                if max(abs(data)) > 0.0:
                    argsDict["ksb"] = argsDict["ks"] + 1

            if testGen.args.lazy_data_gen:
                # Data will be created later by the data generator library
                data = None

            if tens_meta["input_type"] == "VARIABLE":
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
            else:
                tens = testGen.ser.addConst(shape, dtypeList[idx], data)

            tens_ser_list.append(tens)
            # Add the meta data to the list using the serializer tensor name
            dg_tens_meta[tens.name] = tens_meta

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
889
890 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100891 def tvgNegate(
892 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
893 ):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100894 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000895 # Integer test
896 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100897 pCount, cCount = op["operands"]
898 assert (
899 pCount == 1 and cCount == 0
900 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100901 # Must create tensors with values within accumulator (int32) negatable
902 # range
903 max_val = (1 << 31) - 1
904 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100905 arr = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100906 rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100907 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000908 tens_ser_list = []
909 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100910 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
911 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000912 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100913 else:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000914 # ERROR_IF or floating point test
915 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100916 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100917 )
918
Jeremy Johnson30476252023-11-20 16:15:30 +0000919 # Set the ADD/SUB data range to half the largest value to avoid infinities
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000920 TVG_FLOAT_HIGH_VALUE_ADDSUB = {
921 DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
922 DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
923 DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
924 }
925
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100926 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100927 def tvgAddSub(
928 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
929 ):
Won Jeon74342e52024-01-09 00:34:40 +0000930 if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000931 # Make sure the integer operation does not cause value saturation - where
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100932 # the number wraps due to limited number of bits to store the answer
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000933 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100934 pCount, cCount = op["operands"]
935 assert (
936 pCount == 2 and cCount == 0
937 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000938 tens_ser_list = []
Won Jeon74342e52024-01-09 00:34:40 +0000939 add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
Jeremy Johnson32bf9012024-03-20 16:32:23 +0000940 data_range = None # Use default
941 if op["op"] in (Op.ADD_SHAPE, Op.SUB_SHAPE):
942 data_range = testGen.args.tensor_shape_range
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100943 a_arr = rng.randTensor(shapeList[0], dtypeList[0], data_range)
944 b_arr = rng.randTensor(shapeList[1], dtypeList[1], data_range)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100945 if add:
946 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
947 else:
948 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
949
950 # Work out the saturation limits
951 max_i32 = (1 << 31) - 1
952 min_i32 = -(1 << 31)
953 max_arr = np.full(shapeList[1], max_i32)
954 min_arr = np.full(shapeList[1], min_i32)
955
956 # Find how much values exceed the maximum/minimums
957 sat_max_arr = np.maximum(res_arr - max_arr, 0)
958 sat_min_arr = np.minimum(res_arr - min_arr, 0)
959
960 if not add:
961 # Swap saturation values and negate values as we need to perform opposite operations
962 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
963
964 # Create new array of unsaturated values by clipping values as needed
965 b_unsat_arr = b_arr
966 if (sat_max_arr != 0).any():
967 # Clip values that cause saturation
968 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
969 # Reduce axes in unsaturated tensor to match original tensor
970 for axis, dim in enumerate(b_arr.shape):
971 if dim != b_unsat_arr.shape[axis]:
972 assert (
973 dim == 1
974 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
975 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
976
977 if (sat_min_arr != 0).any():
978 # Clip values that cause saturation
979 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
980 # Reduce axes in unsaturated tensor to match original tensor
981 for axis, dim in enumerate(b_arr.shape):
982 if dim != b_unsat_arr.shape[axis]:
983 assert (
984 dim == 1
985 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
986 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
987
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000988 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100989 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
990 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000991 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100992 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
993 )
994
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000995 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100996 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000997 # ERROR_IF or floating point test
998 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100999 rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001000 )
1001 if data_range:
1002 argsDict["data_range"] = data_range
1003
1004 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001005 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001006 )
1007
1008 @staticmethod
1009 def tvgCondIfWhileLoop(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001010 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001011 ):
1012 if dtypeList[0] in (
1013 DType.INT32,
1014 DType.INT16,
1015 DType.INT8,
1016 ):
1017 # Limit input tensors with cond_if_binary or while_loop to stop
1018 # saturation of add/sub ops with int32 and keep all logical shift
1019 # values between 0 to 31 for int16 or int8
Jeremy Johnson587cc842024-02-08 11:45:44 +00001020 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001021 pCount, cCount = op["operands"]
1022 pRemain = pCount
Jeremy Johnson587cc842024-02-08 11:45:44 +00001023 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001024 for idx, shape in enumerate(shapeList[:]):
1025 if dtypeList[0] == DType.INT32:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001026 arr = rng.randTensor(shapeList[idx], DType.INT16)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001027 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001028 arr = np.int32(rng.integers(low=0, high=32, size=shapeList[idx]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001029 if pRemain > 0:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001030 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001031 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
1032 )
1033 pRemain -= 1
1034 else:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001035 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001036 testGen.ser.addConst(shape, dtypeList[idx], arr)
1037 )
1038
Jeremy Johnson587cc842024-02-08 11:45:44 +00001039 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001040 else:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001041 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001042 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001043 )
1044
1045 @staticmethod
1046 def tvgArithmeticRightShift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001047 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001048 ):
Jeremy Johnson587cc842024-02-08 11:45:44 +00001049 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001050 pCount, cCount = op["operands"]
1051 # Force value of operand[1] to be within [0, num_bits]
1052 assert (
1053 pCount == 2 and cCount == 0
1054 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
1055
Jeremy Johnson587cc842024-02-08 11:45:44 +00001056 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001057 for idx, shape in enumerate(shapeList[:]):
1058 if idx == 1:
1059 if dtypeList[idx] == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001060 arr = np.int32(rng.integers(low=0, high=8, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001061 elif dtypeList[idx] == DType.INT16:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001062 arr = np.int32(rng.integers(low=0, high=16, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001063 elif dtypeList[idx] == DType.INT32:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001064 arr = np.int32(rng.integers(low=0, high=32, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001065 elif error_name == ErrorIf.WrongInputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001066 arr = np.int32(rng.integers(low=0, high=8, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001067 else:
1068 raise Exception("OpArithmeticRightShift: invalid input dtype")
1069 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001070 arr = rng.randTensor(shape, dtypeList[idx])
Jeremy Johnson587cc842024-02-08 11:45:44 +00001071 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001072
Jeremy Johnson587cc842024-02-08 11:45:44 +00001073 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001074
1075 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001076 def tvgReshape(
1077 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1078 ):
Won Jeon64e4bfe2024-01-18 06:31:55 +00001079 dtypeList[1] = DType.SHAPE
1080 shapeList[1] = [len(argsDict["new_shape"])]
1081 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1082 argsDict["fixed_data"] = [None, argsDict["new_shape"]]
1083
1084 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001085 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Won Jeon64e4bfe2024-01-18 06:31:55 +00001086 )
1087
1088 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001089 def tvgRescale(
1090 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1091 ):
Tai Ly6e1e2bc2024-03-01 20:59:32 +00001092 scale32 = argsDict["scale"]
1093 multiplier_arr = argsDict["multiplier"]
1094 shift_arr = argsDict["shift"]
1095
1096 if scale32:
1097 dtypeList[1] = DType.INT32
1098 else:
1099 dtypeList[1] = DType.INT16
1100 shapeList[1] = [len(multiplier_arr)]
1101 dtypeList[2] = DType.INT8
1102 shapeList[2] = [len(shift_arr)]
1103 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1104 argsDict["fixed_data"] = [None, multiplier_arr, shift_arr]
1105
1106 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001107 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Tai Ly6e1e2bc2024-03-01 20:59:32 +00001108 )
1109
1110 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001111 def tvgPad(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Tai Lye095da72024-01-25 22:00:18 +00001112 # argsDict["pad"] is 2D array, need to flatten it to get list of values
1113 pad_values = argsDict["pad"].flatten()
1114 dtypeList[1] = DType.SHAPE
1115 shapeList[1] = [len(pad_values)]
1116 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1117 argsDict["fixed_data"] = [None, pad_values]
1118
1119 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001120 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Tai Lye095da72024-01-25 22:00:18 +00001121 )
1122
1123 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001124 def tvgSlice(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
TatWai Chongf15bad82024-01-31 21:33:27 -08001125 dtypeList[1] = DType.SHAPE
1126 shapeList[1] = [len(argsDict["start"])]
1127 dtypeList[2] = DType.SHAPE
1128 shapeList[2] = [len(argsDict["size"])]
1129 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1130 argsDict["fixed_data"] = [None, argsDict["start"], argsDict["size"]]
1131
1132 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001133 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
TatWai Chongf15bad82024-01-31 21:33:27 -08001134 )
1135
1136 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001137 def tvgTile(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Won Jeon64e4bfe2024-01-18 06:31:55 +00001138 dtypeList[1] = DType.SHAPE
1139 shapeList[1] = [len(argsDict["multiples"])]
1140 argsDict["fixed_data"] = [None, argsDict["multiples"]]
1141
1142 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001143 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Won Jeon64e4bfe2024-01-18 06:31:55 +00001144 )
1145
1146 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001147 def tvgSelect(
1148 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1149 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001150 # Set datatype of condition tensor to boolean
1151 dtypeList[0] = DType.BOOL
1152
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00001153 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001154 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001155 )
1156
1157 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001158 def tvgIntDiv(
1159 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1160 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001161 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001162 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001163 pCount, cCount = op["operands"]
1164 assert (
1165 pCount == 2 and cCount == 0
1166 ), "Op.INTDIV must have 2 placeholders, 0 consts"
1167
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001168 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001169
1170 # Two invalid cases for Op.INTDIV:
1171 # 1. divisor == 0
1172 # 2. dividend == -(1<<31) and divisor == -1
1173 while True:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001174 dividend_arr = rng.randTensor(shapeList[0], dtypeList[0])
1175 divisor_arr = rng.randTensor(shapeList[1], dtypeList[1])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001176
1177 if (divisor_arr == 0).any():
1178 continue
1179
1180 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
1181 continue
1182
1183 break
1184
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001185 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001186 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1187 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001188 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001189 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1190 )
1191
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001192 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001193 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001194 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001195 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001196 )
1197
Jeremy Johnson30476252023-11-20 16:15:30 +00001198 # Set the MUL data range to the square root of the largest value
1199 # to avoid infinities
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001200 TVG_FLOAT_HIGH_VALUE_MUL = {
1201 DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1202 DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1203 DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1204 }
1205
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001206 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001207 def tvgMul(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001208 if error_name is not None or dtypeList[0] in (
1209 DType.FP16,
1210 DType.BF16,
1211 DType.FP32,
1212 ):
1213 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001214 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001215 rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001216 )
1217 if data_range:
1218 argsDict["data_range"] = data_range
1219
Jeremy Johnson0a042992024-02-28 13:20:05 +00001220 if dtypeList[0] != DType.SHAPE:
1221 # Need to supply shift tensor for MUL (not needed for MUL_SHAPE)
1222 dtypeList[2] = DType.INT8
1223 shapeList[2] = [1]
1224 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1225 argsDict["fixed_data"] = [None, None, [argsDict["shift"]]]
1226
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001227 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001228 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001229 )
1230 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001231 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001232 pCount, cCount = op["operands"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001233
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001234 tens_ser_list = []
1235
1236 # Make sure multiply result in int32 range
Won Jeon74342e52024-01-09 00:34:40 +00001237 if dtypeList[0] == DType.SHAPE:
1238 shift = 0
1239 else:
1240 shift = argsDict["shift"]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001241 if dtypeList[0] == DType.INT8:
1242 num_bits = 8
1243 elif dtypeList[0] == DType.INT16:
1244 num_bits = 16
Won Jeon74342e52024-01-09 00:34:40 +00001245 elif dtypeList[0] in (DType.INT32, DType.SHAPE):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001246 num_bits = 32
1247 elif error_name == ErrorIf.WrongInputType:
1248 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001249 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001250 raise Exception(
1251 f"OpMul: invalid input dtype {gtu.DTYPE_ATTRIBUTES[dtypeList[0]]['str']}"
1252 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001253
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001254 for idx, shape in enumerate(shapeList[:]):
Won Jeon74342e52024-01-09 00:34:40 +00001255 if dtypeList[idx] == DType.SHAPE:
1256 low = testGen.args.tensor_shape_range[0]
1257 high = testGen.args.tensor_shape_range[1]
1258 else:
1259 low = -(2 ** (num_bits - 1))
1260 high = (2 ** (num_bits - 1)) - 1
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001261
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001262 a_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[0]))
1263 b_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[1]))
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001264
1265 i = 0
1266 while True:
1267
1268 a_arr_64 = a_arr.astype(np.int64)
1269 b_arr_64 = b_arr.astype(np.int64)
1270
1271 if shift > 0:
1272 rounding = 1 << (shift - 1)
1273 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001274 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001275 result_arr = a_arr_64 * b_arr_64
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001276
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001277 if (result_arr > -(2**31)).all() and (
1278 result_arr <= ((2**31) - 1)
1279 ).all():
1280 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001281
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001282 i = i + 1
1283 a_arr = a_arr // 2
1284 b_arr = b_arr // 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001285
Won Jeon74342e52024-01-09 00:34:40 +00001286 if dtypeList[0] == DType.SHAPE:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001287 # MUL_SHAPE with 2 inputs
Won Jeon74342e52024-01-09 00:34:40 +00001288 tens_ser_list.append(
1289 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
1290 )
1291 tens_ser_list.append(
1292 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
1293 )
1294 else:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001295 # MUL with 3 inputs (3rd is shift)
Won Jeon74342e52024-01-09 00:34:40 +00001296 tens_ser_list.append(
1297 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1298 )
1299 tens_ser_list.append(
1300 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1301 )
Jeremy Johnson0a042992024-02-28 13:20:05 +00001302 tens_ser_list.append(
1303 testGen.ser.addPlaceholder([1], DType.INT8, np.int8([shift]))
1304 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001305
1306 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001307
1308 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001309 def tvgConcat(
1310 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1311 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001312 count = len(shapeList) - testGen.args.num_const_inputs_concat
1313 if count < 1:
1314 count = 1
1315 if testGen.args.num_const_inputs_concat == 0:
1316 count = len(shapeList)
1317
Won Jeon74342e52024-01-09 00:34:40 +00001318 op = testGen.TOSA_OP_LIST[opName]
1319 if op["op"] == Op.CONCAT_SHAPE:
1320 # Set the axis to 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001321 shapeList = TosaTensorGen.tgConcatConstInput(rng, shapeList, 0, error_name)
Won Jeon74342e52024-01-09 00:34:40 +00001322 else:
1323 shapeList = TosaTensorGen.tgConcatConstInput(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001324 rng, shapeList, argsDict["axis"], error_name
Won Jeon74342e52024-01-09 00:34:40 +00001325 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001326
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001327 # Override default pCount/cCount for operator
1328 argsDict["p_count"] = count
1329 argsDict["c_count"] = len(shapeList) - count
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001330
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001331 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001332 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001333 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001334
1335 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001336 def tvgLogicalShift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001337 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001338 ):
1339 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001340 pCount, cCount = op["operands"]
1341 assert (
1342 pCount == 2 and cCount == 0
1343 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001344 values_arr = rng.randTensor(shapeList[0], dtypeList[0])
1345 shift_arr = np.int32(rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001346 tens_ser_list = []
1347 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001348 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1349 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001350 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001351 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1352 )
1353
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001354 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001355
    @staticmethod
    def tvgEqual(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Generate tensor values for EQUAL tests.

        For dtypes not supported by the compliance data generator the two
        inputs are created here directly, with a number of positions forced
        to be equal so the comparison result is not trivially all-false.
        Otherwise the lazy default generator is used.
        """
        if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
            # Integer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"

            a_arr = rng.randTensor(shapeList[0], dtypeList[0])
            b_arr = rng.randTensor(shapeList[1], dtypeList[1])

            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        rng.integers(0, max(shapeList[0][axis], shapeList[1][axis]))
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                # Copy b's value into a at the chosen (broadcast compatible)
                # position so the tensors compare equal there
                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )
1400
1401 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001402 def tvgReduceSum(
1403 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1404 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001405 dtype = dtypeList[0]
1406 if dtype == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001407 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001408 pCount, cCount = op["operands"]
1409 assert (
1410 pCount == 1 and cCount == 0
1411 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
1412 # Limit values so that the sum cannot exceed the range of an int32 during
1413 # summation of any axis
1414 range_val = int((1 << 31) / max(shapeList[0]))
1415 values_arr = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001416 rng.integers(low=-range_val, high=range_val, size=shapeList[0])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001417 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001418 tens_ser_list = []
1419 tens_ser_list.append(
Jeremy Johnson30476252023-11-20 16:15:30 +00001420 testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001421 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001422 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001423 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001424 # ERROR_IF or dot product floating point test
Jeremy Johnson30476252023-11-20 16:15:30 +00001425 if (
1426 error_name is None
1427 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
1428 ):
1429 # Limit ranges for (non error & non compliance) tests by using
1430 # values that can be summed on any axis to not hit infinity
1431 highval_lookup = {
1432 dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
1433 / max(shapeList[0])
1434 }
1435 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001436 rng, dtype, highval_lookup
Jeremy Johnson30476252023-11-20 16:15:30 +00001437 )
1438 assert data_range is not None
1439 argsDict["data_range"] = data_range
1440
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001441 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001442 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001443 )
1444
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001445 @staticmethod
1446 def tvgReduceProduct(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001447 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001448 ):
1449 dtype = dtypeList[0]
1450 if error_name is None:
1451 # Limit ranges for (non error) tests by using
1452 # values that can be multiplied on any axis to not hit infinity
1453 highval_lookup = {
1454 dtype: math.pow(
1455 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
1456 1 / max(shapeList[0]),
1457 )
1458 }
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001459 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001460 assert data_range is not None
1461 argsDict["data_range"] = data_range
1462
1463 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001464 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001465 )
1466
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001467 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001468 def tvgResize(
1469 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1470 ):
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001471 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001472 rng,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001473 dtypeList[0],
1474 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1475 )
1476 if data_range:
1477 argsDict["data_range"] = data_range
1478 # Needed for compliance
1479 argsDict["max_abs_value"] = data_range[1]
1480
Tai Lyc5c2a7e2024-02-22 23:26:28 +00001481 scale_values = argsDict["scale"]
1482 offset_values = argsDict["offset"]
1483 border_values = argsDict["border"]
1484 dtypeList[1] = DType.SHAPE
1485 dtypeList[2] = DType.SHAPE
1486 dtypeList[3] = DType.SHAPE
1487 shapeList[1] = [len(scale_values)]
1488 shapeList[2] = [len(offset_values)]
1489 shapeList[3] = [len(border_values)]
1490 argsDict["fixed_data"] = [None, scale_values, offset_values, border_values]
1491
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001492 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001493 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001494 )
1495
Jeremy Johnson30476252023-11-20 16:15:30 +00001496 # Set the POW exponent high data range
1497 TVG_FLOAT_HIGH_VALUE_POW_EXP = {
1498 DType.FP32: 10.0,
1499 DType.FP16: 10.0,
1500 DType.BF16: 10.0,
1501 }
1502 # POW highest base value (within a safe margin of error) that can be raised
1503 # to +ve exponent that doesn't become Infinity
1504 TVG_FLOAT_HIGH_VALUE_POW_BASE = {
1505 DType.FP32: math.floor(
1506 math.pow(
1507 TVG_FLOAT_HIGH_VALUE[DType.FP32],
1508 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1509 )
1510 ),
1511 DType.FP16: math.floor(
1512 math.pow(
1513 TVG_FLOAT_HIGH_VALUE[DType.FP16],
1514 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1515 )
1516 ),
1517 DType.BF16: math.floor(
1518 math.pow(
1519 TVG_FLOAT_HIGH_VALUE[DType.BF16],
1520 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1521 )
1522 ),
1523 }
1524 # POW lowest base value (within a safe margin of error) that can be raised
1525 # to -ve exponent that doesn't become Infinity
1526 TVG_FLOAT_LOW_VALUE_POW_BASE = {
1527 DType.FP32: math.ceil(
1528 math.pow(
1529 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
1530 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1531 )
1532 * 1000
1533 )
1534 / 1000,
1535 DType.FP16: math.ceil(
1536 math.pow(
1537 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
1538 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1539 )
1540 * 1000
1541 )
1542 / 1000,
1543 DType.BF16: math.ceil(
1544 math.pow(
1545 1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
1546 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1547 )
1548 * 1000
1549 )
1550 / 1000,
1551 }
1552
1553 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001554 def tvgPow(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001555 if error_name is not None:
1556 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001557 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001558 )
1559 dtype = dtypeList[0]
1560 # Different ranges for POW
1561 test_set = argsDict["s"]
1562 if test_set == 0:
1563 # Positive base with fractional exponent
1564 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001565 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001566 dtype,
1567 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1568 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1569 )
1570 exp_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001571 rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
Jeremy Johnson30476252023-11-20 16:15:30 +00001572 )
1573 exp_round = False
1574 else:
1575 # Integer exponent
1576 exp_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001577 rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
Jeremy Johnson30476252023-11-20 16:15:30 +00001578 )
1579 exp_round = True
1580 if test_set == 1:
1581 # Positive base
1582 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001583 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001584 dtype,
1585 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1586 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1587 )
1588 else:
1589 assert test_set == 2
1590 # Negative base
1591 # Supply new look up tables with negative values
1592 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001593 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001594 dtype,
1595 {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
1596 {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
1597 )
1598
1599 data_range_list = (
1600 {
1601 "range": base_range,
1602 },
1603 {
1604 "range": exp_range,
1605 "round": exp_round,
1606 },
1607 )
1608 argsDict["data_range_list"] = data_range_list
1609 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001610 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001611 )
1612
1613 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001614 def tvgLogRsqrt(
1615 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1616 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001617 # LOG & RSQRT data range from lowest expressible positive number to
1618 # largest to avoid NaNs
1619 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001620 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001621 dtypeList[0],
1622 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1623 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1624 )
1625 if data_range:
1626 argsDict["data_range"] = data_range
1627
1628 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001629 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001630 )
1631
    # Set the EXP data range to the log of the largest to smallest values
    # to avoid infinities or making the result zero
    TVG_FLOAT_HIGH_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
    # Lower bound: log of the lowest value, so exp() does not flush to zero
    TVG_FLOAT_LOW_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
    }
1644
1645 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001646 def tvgExp(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001647 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001648 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001649 dtypeList[0],
1650 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1651 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1652 )
1653 if data_range:
1654 argsDict["data_range"] = data_range
1655
1656 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001657 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001658 )
1659
1660 @staticmethod
1661 def tvgFullyConnected(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001662 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson30476252023-11-20 16:15:30 +00001663 ):
1664 dtype = dtypeList[0]
1665 if (
1666 error_name is None
1667 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
Jeremy Johnson718f3472023-11-30 14:18:19 +00001668 and dtype in (DType.BF16,)
Jeremy Johnson30476252023-11-20 16:15:30 +00001669 ):
Jeremy Johnson718f3472023-11-30 14:18:19 +00001670 # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
Jeremy Johnson30476252023-11-20 16:15:30 +00001671 # Limit ranges for (non error & non compliance) FP tests by using
1672 # values that can be multiplied on any axis to not hit infinity/NaN
1673 IC = shapeList[0][1]
1674 highval_lookup = {
1675 dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
1676 }
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001677 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnson30476252023-11-20 16:15:30 +00001678 assert data_range is not None
1679 argsDict["data_range"] = data_range
1680
1681 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001682 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001683 )
1684
Jeremy Johnson708da822023-11-15 16:25:45 +00001685 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001686 def tvgCast(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson708da822023-11-15 16:25:45 +00001687 in_dtype = dtypeList[0]
1688 out_dtype = argsDict["out_type"]
1689 # Create look up to limit input tensor to output type maximums to avoid
1690 # FP infinities and saturation of integers
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001691 out_range = rng.dTypeRange(out_dtype, high_inclusive=True)
Jeremy Johnson708da822023-11-15 16:25:45 +00001692 highval_lookup = {in_dtype: out_range[1]}
1693 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001694 rng,
Jeremy Johnson708da822023-11-15 16:25:45 +00001695 in_dtype,
1696 highval_lookup,
1697 )
1698
1699 assert data_range is not None
1700 argsDict["data_range"] = data_range
1701
1702 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001703 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson708da822023-11-15 16:25:45 +00001704 )
1705
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001706 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001707 def tvgGather(
1708 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1709 ):
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001710 K = shapeList[0][1]
1711
1712 # Fix the type of the indices tensor
1713 dtypeList[1] = DType.INT32
1714
1715 dtype = dtypeList[0]
1716 if not gtu.dtypeIsSupportedByCompliance(dtype):
1717 # Test unsupported by data generator
1718 op = testGen.TOSA_OP_LIST[opName]
1719 pCount, cCount = op["operands"]
1720 assert (
1721 pCount == 2 and cCount == 0
1722 ), "Op.GATHER must have 2 placeholders, 0 consts"
1723
1724 tens_ser_list = []
1725 for idx, shape in enumerate(shapeList):
1726 dtype = dtypeList[idx]
1727 if idx != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001728 arr = rng.randTensor(shape, dtype)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001729 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1730 else:
1731 # Limit data range of indices tensor upto K (exclusive)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001732 arr = rng.randTensor(shape, dtype, (0, K))
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001733 # To match old functionality - create indices as CONST
1734 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1735
1736 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1737
1738 else:
1739 # ERROR_IF or floating point test
1740 # Use inclusive values upto index K for indices tensor
1741 data_range_list = (
1742 {"range": None},
1743 {"range": (0, K - 1)},
1744 )
1745 argsDict["data_range_list"] = data_range_list
1746
1747 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001748 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001749 )
1750
    @staticmethod
    def tvgScatter(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Generate tensor values for SCATTER tests.

        Ensures the indices tensor only addresses valid, unique locations in
        dimension K of the values_in tensor.  Types unsupported by the
        compliance data generator are created directly here; otherwise the
        lazy default generator is used with per-tensor data ranges.
        """
        # K - the scatter dimension of values_in; W - indices per batch
        K = shapeList[0][1]
        W = shapeList[2][1]

        # Work out an indices tensor here with data that doesn't exceed the
        # dimension K of the values_in tensor and does NOT repeat the same K
        # location as needed by the spec:
        # "It is not permitted to repeat the same output index within a single
        # SCATTER operation and so each output index occurs at most once."
        assert K >= W, "Op.SCATTER W must be smaller or equal to K"

        # Fix the type of the indices tensor
        dtypeList[1] = DType.INT32

        dtype = dtypeList[0]
        if not gtu.dtypeIsSupportedByCompliance(dtype):
            # Test unsupported by data generator
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 3 and cCount == 0
            ), "Op.SCATTER must have 3 placeholders, 0 consts"

            tens_ser_list = []
            for idx, shape in enumerate(shapeList):
                dtype = dtypeList[idx]
                if idx != 1:
                    arr = rng.randTensor(shape, dtype)
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    # Create the indices array
                    assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
                    arr = []
                    for n in range(shape[0]):
                        # Get a shuffled list of output indices (0 to K-1) and
                        # limit length to W - guarantees unique indices per batch
                        arr.append(rng.permutation(K)[:W])
                    indices_arr = np.array(arr, dtype=np.int32)  # (N, W)
                    # To match old functionality - create indices as CONST
                    tens_ser_list.append(
                        testGen.ser.addConst(shape, dtype, indices_arr)
                    )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        else:
            # ERROR_IF or floating point test
            # Use inclusive values upto index K for indices tensor
            data_range_list = (
                {"range": None},
                {"range": (0, K - 1)},
                {"range": None},
            )
            argsDict["data_range_list"] = data_range_list

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )
1812
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001813
1814class TosaArgGen:
1815 """Argument generators create exhaustive or random lists of attributes for
1816 operators that take attributes or other parameters.
1817
1818 The return value is a list of (descriptive_name, [arglist]) tuples where
1819 the descriptive_name is appended to the test name and the arglist is expanded
1820 as arguments to the operator build function.
1821 """
1822
    def __init__(self):
        # No state required - all argument generators are static methods
        pass
1825
    @staticmethod
    def _add_data_generators(testGen, opName, shapeList, dtype, arg_list, error_name):
        """Add extra tests for each type of data generator for this op.

        Takes (arg_str, args_dict) tuples from arg_list and returns a new
        list where each entry is duplicated per applicable data generator
        type (and per test set for multi-set generators), recording
        "dg_type" (and "s" for test sets) in each copied args_dict.
        """
        if (
            error_name is None
            and "data_gen" in testGen.TOSA_OP_LIST[opName]
            and gtu.dtypeIsSupportedByCompliance(dtype)
        ):
            # Use the generator types listed against the op for this dtype class
            if gtu.dtypeIsFloat(dtype):
                dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
            else:
                dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
        else:
            # Error test or No data generator types listed - assume random
            dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)

        # Expand arg list with other data generator types
        new_arg_list = []
        for dg_type in dataGenTypesList:
            for arg_str, args_dict in arg_list:

                if dg_type == gtu.DataGenType.FULL_RANGE:
                    tensor_size = gtu.product(shapeList[0])
                    if tensor_size >= gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]:
                        # Large enough tensor data size for full range, add a single test
                        num_test_sets = 0
                    else:
                        # Not enough data size for full range of values, revert to random numbers
                        # NB: this rebinding persists for later arg_list entries,
                        # which is harmless as tensor_size is the same for all of them
                        dg_type = gtu.DataGenType.PSEUDO_RANDOM

                if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                    if error_name is None:
                        # Honour any requested number of test sets (0 = single test)
                        num_test_sets = (
                            args_dict["num_test_sets"]
                            if "num_test_sets" in args_dict
                            else 0
                        )
                    else:
                        # Add single test for pseudo random
                        num_test_sets = 0

                elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                    # Extra tests for each dot product test set
                    dot_products = args_dict["dot_products"]
                    if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
                        # Too few calculations for a meaningful compliance test
                        shape_info = (
                            " ({})".format(testGen.shapeStr(args_dict["shape"]))
                            if "shape" in args_dict
                            else ""
                        )
                        logger.info(
                            f"Skipping {opName}{shape_info} {gtu.DTYPE_ATTRIBUTES[dtype]['json']} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
                        )
                        continue
                    # KS and acc_type is required by all dot product generators
                    assert "ks" in args_dict
                    assert "acc_type" in args_dict

                    num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS

                if num_test_sets > 0:
                    # One test per set, with the set number in the name and args
                    for s in range(0, num_test_sets):
                        set_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
                        set_args_dict = args_dict.copy()
                        set_args_dict["s"] = s
                        set_args_dict["dg_type"] = dg_type
                        new_arg_list.append((set_arg_str, set_args_dict))
                else:
                    # Default is a single test
                    new_args_dict = args_dict.copy()
                    new_args_dict["dg_type"] = dg_type
                    new_arg_list.append((arg_str, new_args_dict))

        return new_arg_list
1900
1901 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001902 def agNone(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001903 """A trivial argument generator for operators that don't take any
1904 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001905 arg_list = TosaArgGen._add_data_generators(
1906 testGen,
1907 opName,
evacha019c96eef2024-02-07 11:21:55 +00001908 shapeList,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001909 dtype,
1910 [("", {})],
1911 error_name,
1912 )
1913 # Return list of tuples: (arg_str, args_dict)
1914 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001915
1916 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001917 def agPow(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001918 """Pow operator needs different test sets to cover random numbers
1919 without creating NaNs or Infs"""
1920 arg_list = TosaArgGen._add_data_generators(
1921 testGen,
1922 opName,
evacha019c96eef2024-02-07 11:21:55 +00001923 shapeList,
Jeremy Johnson30476252023-11-20 16:15:30 +00001924 dtype,
1925 [("", {"num_test_sets": 3})],
1926 error_name,
1927 )
1928 # Return list of tuples: (arg_str, args_dict)
1929 return arg_list
1930
1931 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001932 def agAxis(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001933 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001934 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001935 shape = shapeList[0]
1936
1937 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001938 # Set too small axis
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001939 axes = [rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001940 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001941 # Set too large axis
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001942 axes = [rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001943 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001944 # Create tests for each dimension
1945 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001946
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001947 opid = testGen.TOSA_OP_LIST[opName]["op"]
1948
1949 for a in axes:
1950 args_dict = {"axis": int(a)}
1951 if opid == Op.REDUCE_SUM:
Jeremy Johnsone52c0a32024-03-11 09:58:24 +00001952 output_shape = shape.copy()
1953 if error_name is None:
1954 # It only matters that we calculate the dot_products correctly
1955 # for non error_if tests as they should never be run
1956 output_shape[a] = 1
1957 args_dict["dot_products"] = gtu.product(output_shape)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001958 args_dict["shape"] = shape
1959 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1960 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1961
1962 arg_list.append(("axis{}".format(a), args_dict))
1963
1964 arg_list = TosaArgGen._add_data_generators(
1965 testGen,
1966 opName,
evacha019c96eef2024-02-07 11:21:55 +00001967 shapeList,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001968 dtype,
1969 arg_list,
1970 error_name,
1971 )
1972 # Return list of tuples: (arg_str, args_dict)
1973 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001974
1975 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001976 def _calculate_sparsity(num_tests, sparsity_factor):
1977 sparsity = num_tests // sparsity_factor + 1
1978 # If there are only a small number of tests, just select them all
1979 if sparsity < 13:
1980 sparsity = 1
1981 # To get a variety of parameter combinations sparsity should not be a
1982 # multiple of 2, 3 or 5
1983 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1984 sparsity += 1
1985 return sparsity
1986
Jeremy Johnsondd975b82024-02-28 17:29:13 +00001987 # Maximum number of error_if variants to produce
Jeremy Johnson87460262024-03-25 09:46:02 +00001988 MAX_TESTS_ERROR_IFS = 3
Jeremy Johnsondd975b82024-02-28 17:29:13 +00001989
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001990 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001991 def agConv(testGen, rng, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson0c716862023-04-13 17:18:19 +01001992 # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001993 arg_list = []
1994
Jeremy Johnson0c716862023-04-13 17:18:19 +01001995 # Shape: Batches, (Depth), Height, Width, Channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001996 ifm_shape = shapeList[0]
Jeremy Johnson0c716862023-04-13 17:18:19 +01001997 # Shape: (OFM channels), (KD), KH, KW, IFM channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001998 filter_shape = shapeList[1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001999
Tai Lyf36f2562024-03-14 16:21:29 +00002000 accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
2001
2002 if error_name == ErrorIf.WrongAccumulatorType:
2003 accum_dtypes = (
2004 [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
2005 )
James Ward8b390432022-08-12 20:48:56 +01002006
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002007 # For op type checks
2008 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002009
2010 # Check the rank
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002011 rank = 5 if op["op"] == Op.CONV3D else 4
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002012 if error_name != ErrorIf.WrongRank:
2013 assert len(ifm_shape) == rank
2014 assert len(filter_shape) == rank
2015
Jeremy Johnson0c716862023-04-13 17:18:19 +01002016 # kernel rank omits channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002017 k_rank = rank - 2
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002018 k_pos = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
Jeremy Johnson0c716862023-04-13 17:18:19 +01002019 k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002020 # compliance size - KS
2021 k_size = gtu.product(k_shape)
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002022 if not op["op"] == Op.DEPTHWISE_CONV2D:
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002023 k_size *= ifm_shape[-1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002024
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002025 def get_conv_output_info(p, s, d, fix_up_padding=False):
2026 # Work out remainders and output dimensions with an
2027 # option to adjust paddings to create a valid operation
2028 nonlocal ifm_shape, k_shape, error_name, k_rank
2029 if fix_up_padding:
2030 p = list(p) # Make paddings editable
2031 outputs_no_stride = []
2032 remainders = []
2033 outputs = []
2034 for index in range(k_rank):
2035 pad_offset = index * 2
2036 fixed = False
2037 # Fix up pad values to produce valid conv2d
2038 while not fixed:
2039 # Output dimension without being adjusted for stride
2040 output_no_stride = (
2041 ifm_shape[index + 1]
2042 - 1
2043 + p[pad_offset]
2044 + p[pad_offset + 1]
2045 - (k_shape[index] - 1) * d[index]
2046 )
2047 # Tensor left over after applying striding
2048 remainder = output_no_stride % s[index]
2049 if not fix_up_padding:
2050 # Just want remainders and outputs
2051 break
2052 if output_no_stride <= 0:
2053 p[pad_offset + 1] += abs(output_no_stride) + 1
2054 continue
2055 if error_name == ErrorIf.ConvOutputShapeNonInteger:
2056 if remainder:
2057 # Conditions to trigger the test
2058 fixed = True
2059 else:
2060 p[pad_offset + 1] += 1
2061 else:
2062 if remainder:
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002063 # Stride will be negative for StrideSmallerOne
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002064 assert remainder > 0 or (
2065 error_name == ErrorIf.StrideSmallerOne and remainder < 0
2066 )
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002067 p[pad_offset + 1] += abs(remainder)
2068 else:
2069 fixed = True
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002070 outputs_no_stride.append(output_no_stride)
2071 remainders.append(remainder)
2072 # Output dimension taking in to account stride
2073 outputs.append((output_no_stride // s[index]) + 1)
2074
2075 if fix_up_padding:
2076 p = tuple(p) # Make the paddings read-only
2077 assert min(outputs_no_stride) > 0, "Fix up did not work!"
2078 return p, remainders, outputs, outputs_no_stride
2079
2080 # Only fix up padding for conv2d and float types currently
2081 fix_up_padding = gtu.dtypeIsFloat(dtypes[0]) and op["op"] == Op.CONV2D
2082 # Allow any size of output dimension
2083 max_dim_size = None
2084 # Include all tests by default
2085 sparsity = 1
2086
2087 # Work out padding, strides and dilation ranges depending on
2088 # error and arguments
2089 if error_name in (
2090 ErrorIf.PadSmallerZero,
2091 ErrorIf.StrideSmallerOne,
2092 ErrorIf.DilationSmallerOne,
2093 ):
2094 # Use specific invalid value(s)
2095 if error_name == ErrorIf.PadSmallerZero:
2096 # Create negative paddings but with positive opposite paddings
2097 neg_pad = rng.choice(range(-5, 0))
2098 p_vals = [neg_pad, abs(neg_pad)]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002099 else:
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002100 p_vals = [0, 0]
2101 if error_name == ErrorIf.StrideSmallerOne:
2102 # Can't use stride=0, as it is used to derive output shape, as a divisor
2103 s_vals = [rng.choice(range(-5, 0))]
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002104 else:
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002105 s_vals = [1]
2106 if error_name == ErrorIf.DilationSmallerOne:
2107 d_vals = [rng.choice(range(-5, 1))]
2108 else:
2109 d_vals = [1]
2110 paddings = {tuple(p_vals) * k_rank}
2111 strides = {tuple(s_vals) * k_rank}
2112 dilations = {tuple(d_vals) * k_rank}
2113
2114 fix_up_padding = True # Need to fix up paddings to be valid
2115
2116 elif testGen.args.level8k and error_name is None:
Jeremy Johnson0c716862023-04-13 17:18:19 +01002117 # Only test 8k levels boundaries
2118 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2119 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2120 bigPadding = bigKernel
2121
2122 dilation_shape = [1] * k_rank
2123 pad_shape = [0] * k_rank * 2
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002124 if op["op"] == Op.CONV3D:
Jeremy Johnson0c716862023-04-13 17:18:19 +01002125 # Small stride apart from for big kernel (see below) to keep
2126 # tensor size/calculation small
2127 stride_shape = [1] * k_rank
2128 for idx in range(k_rank):
2129 pad_offset = idx * 2
2130 if k_shape[idx] == bigKernel:
2131 # Padding shape needs to account for tensor shape
2132 pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
2133 pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
2134 # Big stride to reduce output size
2135 stride_shape[idx] = bigKernel
2136 else:
2137 # Account for kernel size
2138 pad_shape[pad_offset] = k_shape[idx] - 1
2139 else:
2140 # Always have a large stride with extra padding and dilation to keep
2141 # tensor calculation reasonable
2142 stride_shape = [bigKernel] * k_rank
2143 for idx in range(k_rank):
2144 # Dilation shape must account for kernel size
2145 dilation_shape[idx] = bigKernel // k_shape[idx]
2146 # Padding shape needs to accommodate tensor/kernel & dilation
2147 pad_offset = idx * 2
2148 pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
2149 pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
2150
2151 strides = {tuple(stride_shape)}
2152 dilations = {tuple(dilation_shape)}
2153 paddings = {tuple(pad_shape)}
2154 # Create a limit for the output dimensions size
2155 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2156
2157 # Currently allow all combinations that are reasonable size
2158 sparsity = 1
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002159 else:
2160 # Generate comprehensive argument lists
2161 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
2162 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
2163 # Stride must be greater than 1 to force non-integer error
2164 startStride = 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
2165 s_vals = [x for x in range(startStride, testGen.args.max_conv_stride + 1)]
2166 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002167
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002168 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
2169 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
2170
2171 if error_name is None and testGen.args.oversize:
2172 # add some oversize argument values
2173 if max(ifm_shape) < 64:
2174 bigPadding = 9
2175 paddings.update(
2176 {
2177 x
2178 for x in itertools.product(
2179 *([[0, bigPadding]] * (k_rank * 2))
2180 )
2181 }
2182 )
2183 bigStride = 8
2184 strides.update(
2185 {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
2186 )
2187 bigDilation = 7
2188 dilations.update(
2189 {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
2190 )
2191
2192 if error_name is None:
2193 # There are too many parameter combinations, so generate them sparsely,
2194 sparsity_factor = 120
2195 sparsity = TosaArgGen._calculate_sparsity(
2196 len(paddings) * len(strides) * len(dilations), sparsity_factor
2197 )
2198
2199 # Run through all the argument options creating valid test cases
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002200 more_tests = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002201 n = 0
Tai Lyf36f2562024-03-14 16:21:29 +00002202 for a in accum_dtypes:
2203 for s in sorted(list(strides)):
2204 for p in sorted(list(paddings)):
2205 for d in sorted(list(dilations)):
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002206 if more_tests and (n % sparsity == 0):
2207 (
2208 p,
2209 remainders,
2210 outputs,
2211 outputs_no_stride,
2212 ) = get_conv_output_info(p, s, d, fix_up_padding)
2213 # Following is like checking each dimension N:
2214 # (ifm_shape[N+1] - 1 + p[N*2] + p[N*2+1]) > d[N] * (k_shape[N] - 1)
2215 if min(outputs_no_stride) <= 0:
2216 # Not a valid operation
2217 n += 1 # Increment count of tests
2218 continue
Tai Lyf36f2562024-03-14 16:21:29 +00002219
2220 if (
2221 # the parameters must produce integer exact output
2222 error_name != ErrorIf.ConvOutputShapeNonInteger
2223 and max(remainders) == 0
2224 ) or (
2225 error_name == ErrorIf.ConvOutputShapeNonInteger
2226 and max(remainders) > 0
2227 ):
2228 if (
2229 max_dim_size is not None
2230 and max(outputs) >= max_dim_size
2231 ):
2232 # Test will consume too much memory - skip it
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002233 logger.debug(
2234 "agConv: Convolution output too big - skipped"
2235 )
Tai Lyf36f2562024-03-14 16:21:29 +00002236 continue
2237
2238 # Compliance - number of dot product calculations
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002239 if op["op"] == Op.DEPTHWISE_CONV2D:
Tai Lyf36f2562024-03-14 16:21:29 +00002240 # N*OH*OW*C*M
2241 dots = gtu.product(
2242 (ifm_shape[0], *outputs, *filter_shape[2:])
2243 )
2244 else:
2245 # N*OH*OW*OC or N*OD*OH*OW*OC
2246 dots = gtu.product(
2247 (ifm_shape[0], *outputs, filter_shape[0])
2248 )
2249 args_dict = {
2250 "acc_type": a,
2251 "stride": s,
2252 "pad": p,
2253 "dilation": d,
2254 "kernel": k_shape,
2255 "ks": k_size,
2256 "dot_products": dots,
2257 "shape": ifm_shape,
2258 }
2259
2260 # Support for larger values than 9 needs different delimiter
2261 delim = "" if max(s + p + d) <= 9 else "x"
2262 arg_list.append(
2263 (
2264 "acc{}_st{}_pad{}_dilat{}".format(
2265 testGen.typeStr(a),
2266 delim.join([str(x) for x in s]),
2267 delim.join([str(x) for x in p]),
2268 delim.join([str(x) for x in d]),
2269 ),
2270 args_dict,
2271 )
2272 )
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002273 if (
2274 error_name
Jeremy Johnson87460262024-03-25 09:46:02 +00002275 and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002276 ):
2277 # Found enough errors
2278 logger.debug(
2279 f"Skipping creating more conv error tests for {error_name}"
2280 )
2281 more_tests = False
Tai Lyf36f2562024-03-14 16:21:29 +00002282 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002283
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002284 arg_list = TosaArgGen._add_data_generators(
2285 testGen,
2286 opName,
evacha019c96eef2024-02-07 11:21:55 +00002287 shapeList,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002288 dtypes[0],
2289 arg_list,
2290 error_name,
2291 )
2292 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002293 return arg_list
2294
2295 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002296 def agFullyConnected(testGen, rng, opName, shapeList, dtypes, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002297
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002298 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002299 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01002300
2301 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002302 accum_dtype = gtu.get_wrong_output_type(opName, rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01002303 elif error_name == ErrorIf.WrongInputType:
2304 # Pick some potentially correct output dtype if input type is incorrect
2305 accum_dtype = DType.INT32
2306 else:
Tai Lyf36f2562024-03-14 16:21:29 +00002307 accum_dtype = dtypes[-1] # use output dtype as accum_dtype
James Ward8b390432022-08-12 20:48:56 +01002308
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002309 # Set up compliance info
2310 args_dict = {
2311 "acc_type": accum_dtype,
2312 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
2313 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
2314 "shape": shapeList[0],
2315 }
2316
2317 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
2318
2319 arg_list = TosaArgGen._add_data_generators(
2320 testGen,
2321 opName,
evacha019c96eef2024-02-07 11:21:55 +00002322 shapeList,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002323 input_dtype,
2324 arg_list,
2325 error_name,
2326 )
2327 # Return list of tuples: (arg_str, args_dict)
2328 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002329
2330 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002331 def agMatMul(testGen, rng, opName, shapeList, dtype, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002332 # Get valid accumulate type(s)
2333 if dtype == DType.INT8:
2334 accum_dtypes = [DType.INT32]
2335 elif dtype == DType.INT16:
2336 accum_dtypes = [DType.INT48]
2337 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002338 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002339 elif dtype == DType.BF16:
2340 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002341 elif dtype == DType.FP32:
2342 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002343 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2344 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002345 elif error_name is None:
2346 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
2347
2348 if error_name == ErrorIf.WrongOutputType:
2349 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002350 accum_dtypes = [gtu.get_wrong_output_type(opName, rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01002351 elif error_name == ErrorIf.WrongInputType:
2352 # Pick some potentially correct output dtype if input type is incorrect
2353 accum_dtypes = [DType.INT32]
2354
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002355 # Set up compliance info
2356 args_dict = {
2357 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
2358 # Set dot_products = N*H*W
2359 "dot_products": gtu.product(
2360 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
2361 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002362 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002363 }
2364
2365 # Create arg tuple of string and dict
2366 arg_list = []
2367 for a in accum_dtypes:
2368 d = args_dict.copy()
2369 d["acc_type"] = a
2370 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01002371
2372 arg_list = TosaArgGen._add_data_generators(
2373 testGen,
2374 opName,
evacha019c96eef2024-02-07 11:21:55 +00002375 shapeList,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002376 dtype,
2377 arg_list,
2378 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002379 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002380 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01002381 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002382
2383 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002384 def agTransposeConv2D(testGen, rng, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002385 arg_list = []
2386
Jeremy Johnson0c716862023-04-13 17:18:19 +01002387 if testGen.args.level8k and error_name is not None:
2388 # Don't produce negative large tests
2389 return arg_list
2390
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002391 ifm_shape = shapeList[0]
2392 filter_shape = shapeList[1]
2393
Tai Lyf36f2562024-03-14 16:21:29 +00002394 accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
2395
2396 if error_name == ErrorIf.WrongAccumulatorType:
2397 accum_dtypes = (
2398 [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
2399 )
James Ward8b390432022-08-12 20:48:56 +01002400
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002401 # Must be rank 4
2402 if error_name != ErrorIf.WrongRank:
2403 assert len(ifm_shape) == 4
2404 assert len(filter_shape) == 4
2405
Jeremy Johnson0c716862023-04-13 17:18:19 +01002406 k_shape = tuple(filter_shape[1:3])
Jeremy Johnson95a67102024-01-10 14:16:39 +00002407 # compliance size - KS
2408 k_size = gtu.product((*k_shape, ifm_shape[3]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002409
Jeremy Johnson0c716862023-04-13 17:18:19 +01002410 if not testGen.args.level8k:
2411 # Generate comprehensive argument lists
2412 # - except for named errors, which use specific invalid value(s)
2413 smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
2414 if error_name == ErrorIf.PadLargerEqualKernel:
2415 max_filter_size = -max(k_shape[0], k_shape[1])
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002416 p_vals = [rng.choice(range(max_filter_size - 10, max_filter_size))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002417 else:
2418 p_vals = [
2419 x
2420 for x in range(
2421 smallest_padding_size, testGen.args.max_conv_padding + 1
2422 )
2423 ]
2424 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2425 if error_name == ErrorIf.StrideSmallerOne:
2426 # Can't use stride=0, as it is used to derive output shape, as a divisor
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002427 s_vals = [rng.choice(range(-5, 0))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002428 else:
2429 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
2430 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002431
Jeremy Johnson0c716862023-04-13 17:18:19 +01002432 if not error_name and testGen.args.oversize:
2433 # add some oversize argument values
2434 if max(ifm_shape) < 64:
2435 bigPadding = 9
2436 paddings.update(
2437 {
2438 x
2439 for x in itertools.product(
2440 *([[smallest_padding_size, bigPadding]] * 4)
2441 )
2442 }
2443 )
2444 bigStride = 8
2445 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
2446
2447 # There are too many parameter combinations, so generate them sparsely,
2448 # very sparse for negative tests
2449 sparsity_factor = 2 if error_name else 10
2450 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
2451 # If there are only a small number of tests, just select them all
2452 if sparsity < 13:
2453 sparsity = 1
2454 # To get a variety of parameter combinations sparsity should not be a
2455 # multiple of 2, 3 or 5
2456 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
2457 sparsity += 1
2458 else:
2459 # Only test 8k levels boundaries
2460 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2461 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2462 bigPadding = bigKernel
2463
2464 pad_shape = [0] * (len(k_shape) * 2)
2465 stride_shape = [1] * len(k_shape)
2466 # The point at which input dimension combined with the stride will
2467 # create large output sizes!
2468 LARGE_SIZE = 2
2469 for idx in range(len(k_shape)):
2470 pad_offset = idx * 2
2471 if k_shape[idx] == bigKernel:
2472 # Set large stride
2473 stride_shape[idx] = bigKernel
2474 # Use negative output padding to reduce shape size
2475 pad_shape[pad_offset] = -(bigPadding - 1)
2476 if ifm_shape[idx + 1] > LARGE_SIZE:
2477 pad_shape[pad_offset + 1] = -(bigPadding - 1)
2478 else:
2479 # The other dimension should be the bigKernel
2480 alt_idx = 1 - idx
2481 if (
2482 k_shape[alt_idx] == bigKernel
2483 and ifm_shape[alt_idx + 1] < LARGE_SIZE
2484 ):
2485 # As the input is small, the large stride won't
2486 # affect the output so we can add some padding
2487 pad_shape[pad_offset + 1] = bigPadding
2488
2489 strides = {tuple(stride_shape)}
2490 paddings = {tuple(pad_shape)}
2491
2492 # Currently allow all combinations that are reasonable size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002493 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002494
2495 n = 0
Tai Lyf36f2562024-03-14 16:21:29 +00002496 for a in accum_dtypes:
2497 for s in sorted(list(strides)):
2498 for p in sorted(list(paddings)):
2499 if n % sparsity == 0:
2500 # Determine the output shape
2501 oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
2502 ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
2503 os = [ifm_shape[0], oh, ow, filter_shape[0]]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002504
Tai Lyf36f2562024-03-14 16:21:29 +00002505 # N*OH*OW*OC
2506 dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
2507 args_dict = {
2508 "acc_type": a,
2509 "stride": s,
2510 "pad": p,
2511 "kernel": k_shape,
2512 "ks": k_size,
2513 "dot_products": dots,
2514 "shape": ifm_shape,
2515 "out_shape": os,
2516 }
Jeremy Johnson95a67102024-01-10 14:16:39 +00002517
Tai Lyf36f2562024-03-14 16:21:29 +00002518 # Support for larger values than 9 needs different delimiter
2519 delim = "" if max(s + p) <= 9 else "x"
2520 arg_list.append(
2521 (
2522 "acc{}_st{}_pad{}_os{}".format(
2523 testGen.typeStr(a),
2524 delim.join([str(x) for x in s]),
2525 delim.join([str(x) for x in p]),
2526 "x".join([str(x) for x in os]),
2527 ),
2528 args_dict,
2529 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002530 )
Tai Lyf36f2562024-03-14 16:21:29 +00002531 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002532
Jeremy Johnson95a67102024-01-10 14:16:39 +00002533 arg_list = TosaArgGen._add_data_generators(
2534 testGen,
2535 opName,
evacha019c96eef2024-02-07 11:21:55 +00002536 shapeList,
Jeremy Johnson95a67102024-01-10 14:16:39 +00002537 dtypes[0],
2538 arg_list,
2539 error_name,
2540 )
2541 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002542 return arg_list
2543
    @staticmethod
    def agPad(testGen, rng, opName, shapeList, dtype, error_name=None):
        """Generate the (arg_str, args_dict) argument list for the PAD op.

        Builds combinations of (before, after) padding values for every
        dimension of the input shape, plus a random pad constant, applying
        sparsity when the combination space gets large.

        Args:
            testGen: test generator holding command line args
            rng: random number generator for pad constants/ERROR_IF values
            opName: operator name
            shapeList: [input_shape, ...]
            dtype: input tensor dtype
            error_name: optional ErrorIf case to create invalid arguments for

        Returns:
            list of (argument string, argument dictionary) tuples
        """
        rank = len(shapeList[0])

        if error_name is None and testGen.args.oversize:
            pad_values = [6, 7, 10, 13]
        elif error_name == ErrorIf.PadSmallerZero:
            pad_values = [x for x in range(-2, 0)]
        else:
            # Exhaustively test combinations of padding on each side of each dimension
            # - the range of padding values is defined by pad_min and pad_max
            pad_min, pad_max = 0, 1
            pad_values = [x for x in range(pad_min, pad_max + 1)]

        # Calculate pad combinations
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        # Pick a random pad constant in whichever (int/float) field matches
        # the dtype; the other field is left as zero
        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = rng.randNumberDType(dtype)
            pad_const_fp = 0
        elif gtu.dtypeIsFloat(dtype):
            pad_const_int = 0
            pad_const_fp = rng.randNumberDType(dtype)
        else:
            assert error_name == ErrorIf.WrongInputType
            pad_const_int = 0
            pad_const_fp = 0

        list_shape_pad_values = list(shape_pad_values)
        # If we are producing tests for rank 6 or greater use sparsity
        if len(list_shape_pad_values) > 1024:
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(list_shape_pad_values), sparsity_factor
            )
        else:
            sparsity = 1

        # Build arg list
        arg_list = []
        for n, paddings in enumerate(list_shape_pad_values):
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                    # NOTE(review): any dimension left without a negative pad
                    # value discards the whole combination - for small shapes
                    # this can drop every combination (see debug log below)
                    if all([p > -1 for p in paddings[i]]):
                        args_valid = False
            if args_valid and n % sparsity == 0:
                # Work out name
                pad_list = []
                for r in range(rank):
                    pad_list.extend(paddings[r])

                delim = "" if max(pad_list) <= 9 else "x"
                name = "pad{}".format(delim.join([str(x) for x in pad_list]))

                args_dict = {
                    "pad": np.array(paddings),
                    "pad_const_int": pad_const_int,
                    "pad_const_fp": pad_const_fp,
                }
                arg_list.append((name, args_dict))

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            logger.debug(
                f"agPad: No PadSmallerZero ErrorIf test created for input shape: {shapeList[0]}"
            )

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            shapeList,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2631
2632 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002633 def agPooling(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002634 arg_list = []
2635
2636 shape = shapeList[0]
2637 if error_name != ErrorIf.WrongRank:
2638 assert len(shape) == 4
2639
Jeremy Johnson0c716862023-04-13 17:18:19 +01002640 test_level8k = testGen.args.level8k and error_name is None
2641
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002642 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002643 startKernel = 2
2644 startPad = 0
2645 if not test_level8k:
2646 # Generate comprehensive argument lists
2647 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2648 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2649 # Stride must be greater than 1 to force non-integer error
2650 s_vals = [
2651 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2652 ]
2653 strides = {x for x in itertools.product(*([s_vals] * 2))}
2654 k_vals = [
2655 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2656 ]
2657 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2658 max_dim_size = None
2659 else:
2660 # Only test 8k levels
2661 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2662 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2663 strides = {(1, bigStride), (bigStride, 4)}
2664 kernels = {(1, bigKernel), (bigKernel, 3)}
2665 paddings = set()
2666 for s in sorted(list(strides)):
2667 for k in sorted(list(kernels)):
2668 padding = []
2669 for idx in range(len(k)):
2670 total_padding = s[idx] - shape[idx + 1] + k[idx]
2671 while total_padding < 0:
2672 # Must meet: shape + padding > kernel
2673 total_padding += s[idx]
2674 if total_padding < k[idx]:
2675 padding.extend([0, total_padding])
2676 else:
2677 # Note this may produce padding >= k[idx] which is not
2678 # allowed - but will be ignored in the creation loop below
2679 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2680 paddings.add(tuple(padding))
2681 # Create a limit for the output dimensions size
2682 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002683
James Ward8b390432022-08-12 20:48:56 +01002684 if opName == "max_pool2d":
2685 accum_dtypes = [None] # max_pool has no accumulate dtype
2686 elif dtype == DType.INT8 or dtype == DType.INT16:
2687 accum_dtypes = [DType.INT32]
2688 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002689 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002690 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002691 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002692 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2693 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002694 elif error_name is None:
2695 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2696 else:
2697 # Set to something for the ErrorIf case which has
2698 # incorrect input data-type
2699 accum_dtypes = [DType.INT32]
2700
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00002701 if error_name == ErrorIf.WrongAccumulatorType:
2702 accum_dtypes = list(gtu.usableDTypes(excludes=accum_dtypes))
2703
Jeremy Johnson0c716862023-04-13 17:18:19 +01002704 if not test_level8k:
2705 if testGen.args.oversize:
2706 # add some oversize argument values
2707 bigStride = 7
2708 bigKernel = 9
2709 strides.update(
2710 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002711 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002712 kernels.update(
2713 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2714 )
2715 if max(shape) < 64:
2716 # padding must be less than the kernel size
2717 bigPadding = bigKernel - 1
2718 paddings.update(
2719 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2720 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002721
Jeremy Johnson87460262024-03-25 09:46:02 +00002722 if error_name:
2723 # Cycle through all error_if tests but we only keep the first few
2724 sparsity = 1
2725 else:
2726 # There are too many parameter combinations, so generate them sparsely
2727 sparsity_factor = 500
2728 sparsity = (
2729 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2730 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002731 else:
2732 # We have already limited test output combinations for 8k tests
2733 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002734
James Ward8b390432022-08-12 20:48:56 +01002735 arg_str = (
2736 "acc{}_st{}_kern{}_pad{}"
2737 if accum_dtypes[0] is not None
2738 else "st{}_kern{}_pad{}"
2739 )
2740
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002741 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01002742 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002743 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01002744
2745 # Support for larger values than 9 needs different delimiter
2746 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01002747 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01002748 delim.join([str(x) for x in stride]),
2749 delim.join([str(x) for x in kern]),
2750 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01002751 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002752 args_dict = {
2753 "stride": stride,
2754 "pad": pad,
2755 "kernel": kern,
2756 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002757 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002758 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
2759 }
James Ward8b390432022-08-12 20:48:56 +01002760
2761 if accum is not None:
2762 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002763 args_dict["acc_type"] = accum
2764 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002765
Jeremy Johnson87460262024-03-25 09:46:02 +00002766 more_tests = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002767 n = 0
James Ward8b390432022-08-12 20:48:56 +01002768 for a in accum_dtypes:
2769 for s in sorted(list(strides)):
2770 for p in sorted(list(paddings)):
2771 for k in sorted(list(kernels)):
2772 if error_name in [
2773 ErrorIf.StrideSmallerOne,
2774 ErrorIf.KernelSmallerOne,
2775 ErrorIf.PadSmallerZero,
2776 ErrorIf.PadLargerEqualKernel,
2777 ]:
2778 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002779 rng, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002780 )
James Ward8b390432022-08-12 20:48:56 +01002781 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002782 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002783 get_arg_list_element(a, sNew, pNew, kNew, shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002784 )
James Ward8b390432022-08-12 20:48:56 +01002785 elif (
Jeremy Johnson87460262024-03-25 09:46:02 +00002786 more_tests
2787 and n % sparsity == 0
James Ward8b390432022-08-12 20:48:56 +01002788 # padding must not exceed the kernel size
2789 and p[0] < k[0]
2790 and p[1] < k[0]
2791 and p[2] < k[1]
2792 and p[3] < k[1]
2793 # the padded shape must exceed the kernel size
2794 and (shape[1] + p[0] + p[1]) > k[0]
2795 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002796 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002797 partial_h = shape[1] + p[0] + p[1] - k[0]
2798 partial_w = shape[2] + p[2] + p[3] - k[1]
2799 remainder_h = partial_h % s[0]
2800 remainder_w = partial_w % s[1]
2801 output_h = partial_h // s[0] + 1
2802 output_w = partial_w // s[1] + 1
Jeremy Johnsonaf090182024-02-13 18:25:39 +00002803 logger.debug(
2804 f"agPooling: {shape} remainder=({remainder_h}, {remainder_w}) output=({output_h}, {output_w})"
2805 )
James Ward8b390432022-08-12 20:48:56 +01002806 if (
2807 # the parameters must produce integer exact output
2808 error_name != ErrorIf.PoolingOutputShapeNonInteger
2809 and remainder_h == 0
2810 and remainder_w == 0
2811 ) or (
2812 error_name == ErrorIf.PoolingOutputShapeNonInteger
2813 and (remainder_h != 0 or remainder_w != 0)
2814 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002815 if (
2816 max_dim_size is not None
2817 and max(output_h, output_w) > max_dim_size
2818 ):
2819 # Test will consume too much memory - skip it
2820 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002821 # Dot products = N*OH*OW*C
2822 dp = gtu.product(
2823 (shape[0], output_h, output_w, shape[3])
2824 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002825 arg_list.append(
2826 get_arg_list_element(a, s, p, k, dp, shape)
2827 )
Jeremy Johnson87460262024-03-25 09:46:02 +00002828 if (
2829 error_name
2830 and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS
2831 ):
2832 # Found enough errors
2833 logger.debug(
2834 f"Skipping creating more pooling error tests for {error_name}"
2835 )
2836 more_tests = False
2837
James Ward8b390432022-08-12 20:48:56 +01002838 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002839
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002840 # Now add data generator types
2841 arg_list = TosaArgGen._add_data_generators(
2842 testGen,
2843 opName,
evacha019c96eef2024-02-07 11:21:55 +00002844 shapeList,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002845 dtype,
2846 arg_list,
2847 error_name,
2848 )
2849
2850 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002851 return arg_list
2852
2853 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002854 def agCast(testGen, rng, opName, shapeList, inDtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002855 arg_list = []
2856
2857 # Enumerate the output types here
2858 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002859 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(inDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002860 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00002861 dtypeList = [
2862 DType.BOOL,
2863 DType.INT16,
2864 DType.INT32,
2865 DType.FP16,
2866 DType.BF16,
2867 DType.FP32,
2868 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002869 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00002870 dtypeList = [
2871 DType.BOOL,
2872 DType.INT8,
2873 DType.INT32,
2874 DType.FP16,
2875 DType.BF16,
2876 DType.FP32,
2877 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002878 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00002879 dtypeList = [
2880 DType.BOOL,
2881 DType.INT8,
2882 DType.INT16,
2883 DType.FP16,
2884 DType.BF16,
2885 DType.FP32,
2886 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002887 elif inDtype == DType.BOOL:
2888 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01002889 elif inDtype == DType.FP16:
Won Jeon2c34b462024-02-06 18:37:00 +00002890 dtypeList = [
2891 DType.INT8,
2892 DType.INT16,
2893 DType.INT32,
2894 DType.FP32,
2895 DType.FP8E4M3,
2896 DType.FP8E5M2,
2897 ]
James Ward24dbc422022-10-19 12:20:31 +01002898 elif inDtype == DType.BF16:
Won Jeon2c34b462024-02-06 18:37:00 +00002899 dtypeList = [
2900 DType.INT8,
2901 DType.INT16,
2902 DType.INT32,
2903 DType.FP32,
2904 DType.FP8E4M3,
2905 DType.FP8E5M2,
2906 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002907 elif inDtype == DType.FP32:
Won Jeon2c34b462024-02-06 18:37:00 +00002908 dtypeList = [
2909 DType.INT8,
2910 DType.INT16,
2911 DType.INT32,
2912 DType.FP16,
2913 DType.BF16,
2914 DType.FP8E4M3,
2915 DType.FP8E5M2,
2916 ]
2917 elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]:
2918 dtypeList = [DType.FP16, DType.BF16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002919 elif error_name == ErrorIf.WrongInputType:
2920 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002921 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002922 else:
2923 raise Exception("Unexpected input dtype: {}".format(inDtype))
2924
2925 for dtype in dtypeList:
Jeremy Johnson708da822023-11-15 16:25:45 +00002926 arg_list.append(
2927 ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
2928 )
2929
2930 # Now add data generator types
2931 arg_list = TosaArgGen._add_data_generators(
2932 testGen,
2933 opName,
evacha019c96eef2024-02-07 11:21:55 +00002934 shapeList,
Jeremy Johnson708da822023-11-15 16:25:45 +00002935 dtype,
2936 arg_list,
2937 error_name,
2938 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002939
2940 return arg_list
2941
2942 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002943 def agRescale(testGen, rng, opName, shapeList, inDtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002944 arg_list = []
2945
2946 # Enumerate the output types here
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002947 for outDtype in [
2948 DType.UINT8,
2949 DType.INT8,
2950 DType.INT16,
2951 DType.INT32,
2952 DType.UINT16,
2953 ]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002954 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002955 outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002956 and error_name == ErrorIf.OutputZeroPointNotZero
2957 ):
2958 continue
2959 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002960 outDtype != DType.UINT16
2961 and error_name == ErrorIf.U16OutputZeroPointNotValid
2962 ) or (
2963 inDtype != DType.UINT16
2964 and error_name == ErrorIf.U16InputZeroPointNotValid
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002965 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002966 # ErrorIfs only valid with UINT16
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002967 continue
2968 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002969 inDtype == DType.UINT8
2970 and outDtype not in [DType.INT8, DType.INT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002971 and error_name != ErrorIf.WrongOutputType
2972 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002973 # The only output dtypes for UINT8 are INT8/INT16, skip all others
2974 continue
2975 if (
2976 inDtype not in [DType.INT8, DType.INT16]
2977 and outDtype == DType.UINT8
2978 and error_name != ErrorIf.WrongOutputType
2979 ):
2980 # The only input dtypes for UINT8 are INT8/INT16, skip all others
2981 continue
2982 if (
2983 inDtype == DType.UINT16
2984 and outDtype != DType.INT16
2985 and error_name != ErrorIf.WrongOutputType
2986 ):
2987 # The only output dtype for UINT16 is INT16, skip all others
2988 continue
2989 if (
2990 inDtype != DType.INT16
2991 and outDtype == DType.UINT16
2992 and error_name != ErrorIf.WrongOutputType
2993 ):
2994 # The only input dtype for UINT16 is INT16, skip all others
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002995 continue
2996 if (
2997 error_name == ErrorIf.WrongOutputType
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002998 and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002999 ):
3000 continue
3001
3002 for scale32 in [False, True]:
3003 if error_name == ErrorIf.ScaleTrue and not scale32:
3004 continue
3005 elif error_name == ErrorIf.ScaleNotTrue and scale32:
3006 continue
3007 for double_round in [False, True]:
3008 if error_name == ErrorIf.ScaleNotTrue and not double_round:
3009 continue
3010 for per_channel in [False, True]:
3011
3012 if (
3013 inDtype == DType.INT48
3014 and scale32
3015 and error_name != ErrorIf.ScaleTrue
3016 ):
3017 # Illegal condition. Must be scale32=False
3018 continue
3019 if (
3020 double_round
3021 and not scale32
3022 and error_name != ErrorIf.ScaleNotTrue
3023 ):
3024 # Illegal condition. ERROR_IF(!scale32 && double_round)
3025 continue
3026
Tai Ly6e1e2bc2024-03-01 20:59:32 +00003027 if per_channel:
3028 nc = shapeList[0][-1]
3029 else:
3030 nc = 1
3031
3032 in_type_width = gtu.dtypeWidth(inDtype)
3033 out_type_width = gtu.dtypeWidth(outDtype)
3034
3035 # Calculate scale based on:
3036 # scale = a *(2^output_width)/(2^input_width))
3037
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003038 a = np.float32(rng.random(size=[nc]))
Tai Ly6e1e2bc2024-03-01 20:59:32 +00003039 scale_arr = a * np.float32(
3040 (1 << out_type_width) / (1 << in_type_width)
3041 )
3042
3043 if scale32:
3044 # Cap the scaling at 2^31 - 1 for scale32
3045 scale_arr = np.clip(
3046 scale_arr, 1.0 / (1 << 31), (1 << 31) - 1
3047 )
3048 else:
3049 # Cap the scaling at 2^15 - 1 for scale16
3050 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
3051
Jeremy Johnsonaf090182024-02-13 18:25:39 +00003052 logger.debug(
3053 f"agRescale: {out_type_width} {in_type_width} -> {scale_arr}"
3054 )
Tai Ly6e1e2bc2024-03-01 20:59:32 +00003055
3056 multiplier_arr = np.int32(np.zeros(shape=[nc]))
3057 shift_arr = np.int32(np.zeros(shape=[nc]))
3058 for i in range(nc):
3059 (
3060 multiplier_arr[i],
3061 shift_arr[i],
3062 ) = TosaQuantGen.computeMultiplierAndShift(
3063 scale_arr[i], scale32
3064 )
3065
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003066 arg_list.append(
3067 (
3068 "out{}_sc{}_dr{}_pc{}".format(
Jeremy Johnson3b0544c2022-10-18 16:32:19 +01003069 testGen.typeStr(outDtype),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003070 int(scale32),
3071 int(double_round),
3072 int(per_channel),
3073 ),
Jeremy Johnson587cc842024-02-08 11:45:44 +00003074 {
3075 "output_dtype": outDtype,
3076 "scale": scale32,
3077 "double_round": double_round,
3078 "per_channel": per_channel,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00003079 "multiplier": multiplier_arr,
3080 "shift": shift_arr,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003081 },
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003082 )
3083 )
3084
Jeremy Johnson587cc842024-02-08 11:45:44 +00003085 arg_list = TosaArgGen._add_data_generators(
3086 testGen,
3087 opName,
evacha019c96eef2024-02-07 11:21:55 +00003088 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003089 inDtype,
3090 arg_list,
3091 error_name,
3092 )
3093 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003094 return arg_list
3095
3096 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003097 def agMul(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003098 arg_list = []
3099
3100 if dtype is DType.INT32:
3101 for p in range(testGen.args.num_rand_permutations):
3102
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003103 shift = rng.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003104 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003105 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003106 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003107
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003108 arg_list = TosaArgGen._add_data_generators(
3109 testGen,
3110 opName,
evacha019c96eef2024-02-07 11:21:55 +00003111 shapeList,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003112 dtype,
3113 arg_list,
3114 error_name,
3115 )
3116 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003117 return arg_list
3118
3119 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003120 def agArithmeticRightShift(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003121 arg_list = []
3122
Jeremy Johnson587cc842024-02-08 11:45:44 +00003123 for round in (True, False):
3124 args_dict = {
3125 "round": round,
3126 }
3127 arg_list.append((f"round{round}", args_dict))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003128
Jeremy Johnson587cc842024-02-08 11:45:44 +00003129 arg_list = TosaArgGen._add_data_generators(
3130 testGen,
3131 opName,
evacha019c96eef2024-02-07 11:21:55 +00003132 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003133 dtype,
3134 arg_list,
3135 error_name,
3136 )
3137 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003138 return arg_list
3139
Luke Hutton57287132023-02-06 14:54:18 +00003140 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003141 def agFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Luke Hutton57287132023-02-06 14:54:18 +00003142 arg_list = []
3143
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003144 shape = shapeList[0]
3145 dot_products = gtu.product(shape)
3146 ks = 2 * shape[1] * shape[2] # 2*H*W
3147 for inverse in (True, False):
3148 args_dict = {
3149 "dot_products": dot_products,
3150 "shape": shape,
3151 "ks": ks,
3152 "acc_type": dtype,
3153 "inverse": inverse,
3154 }
3155 arg_list.append((f"inverse{inverse}", args_dict))
Luke Hutton57287132023-02-06 14:54:18 +00003156
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003157 arg_list = TosaArgGen._add_data_generators(
3158 testGen,
3159 opName,
evacha019c96eef2024-02-07 11:21:55 +00003160 shapeList,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003161 dtype,
3162 arg_list,
3163 error_name,
3164 )
3165 # Return list of tuples: (arg_str, args_dict)
Luke Hutton57287132023-02-06 14:54:18 +00003166 return arg_list
3167
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003168 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003169 def agRFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003170 arg_list = []
3171
3172 shape = shapeList[0]
3173 dot_products = gtu.product(shape)
3174 ks = shape[1] * shape[2] # H*W
3175 args_dict = {
3176 "dot_products": dot_products,
3177 "shape": shape,
3178 "ks": ks,
3179 "acc_type": dtype,
3180 }
3181 arg_list.append(("", args_dict))
3182
3183 arg_list = TosaArgGen._add_data_generators(
3184 testGen,
3185 opName,
evacha019c96eef2024-02-07 11:21:55 +00003186 shapeList,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003187 dtype,
3188 arg_list,
3189 error_name,
3190 )
3191 # Return list of tuples: (arg_str, args_dict)
3192 return arg_list
3193
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003194 # Helper function for reshape. Gets some factors of a larger number.
3195 @staticmethod
3196 def getFactors(val, start=1):
3197 factors = []
3198
3199 for i in range(start, int(np.sqrt(val)) + 1):
3200 if (val % i) == 0:
3201 factors.append(i)
3202
3203 return factors
3204
3205 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003206 def agReshape(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003207 arg_list = []
3208
3209 origShape = shapeList[0]
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003210 totalElements = gtu.product(origShape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003211 factors = TosaArgGen.getFactors(totalElements)
3212
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003213 # Find new shapes up to the number of permutations asked for
3214 # This code is NOT fast. Fortunately, the numbers are fairly small.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003215 for p in range(testGen.args.num_rand_permutations):
Jeremy Johnsondd975b82024-02-28 17:29:13 +00003216 # Rank from 1 to MAX_TENSOR_RANK
3217 newRank = rng.randInt(1, (gtu.MAX_TENSOR_RANK + 1))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003218 if len(factors) < newRank:
3219 continue
3220
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003221 # escape_counter limits the generation of new shapes to a reasonable time
3222 for escape_counter in range(100):
3223
3224 # Generate the new shape of the chosen new rank
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003225 newShape = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003226 remainingElements = totalElements
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003227 shuffledFactors = rng.permutation(factors)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003228 for i in range(1, newRank):
3229 # pick rank-1 factors
3230 newShape.append(shuffledFactors[0])
3231 remainingElements = remainingElements // shuffledFactors[0]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003232 shuffledFactors = rng.permutation(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003233 TosaArgGen.getFactors(remainingElements)
3234 )
3235 newShape.append(remainingElements)
3236
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003237 # Check for duplicates
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003238 duplicate = False
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003239 for name, args_dict in arg_list:
3240 if args_dict["new_shape"] == newShape:
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003241 duplicate = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003242 break
3243
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003244 if not duplicate:
3245 outShape = "x".join([str(x) for x in newShape])
3246 arg_list.append(
3247 (
3248 "perm{}_rank{}_out{}".format(p, newRank, outShape),
3249 {"new_shape": newShape},
3250 )
3251 )
3252 # Found an output shape for this permutation
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003253 break
3254
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003255 # Now add data generator types
3256 arg_list = TosaArgGen._add_data_generators(
3257 testGen,
3258 opName,
evacha019c96eef2024-02-07 11:21:55 +00003259 shapeList,
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003260 dtype,
3261 arg_list,
3262 error_name,
3263 )
3264
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003265 return arg_list
3266
3267 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003268 def agTranspose(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003269 arg_list = []
3270
3271 ifm_shape = shapeList[0]
3272
3273 if error_name == ErrorIf.IndexOutsideBounds:
3274 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
3275 incorrect_small_index = range(-len(ifm_shape), 0)
3276 permutations = [p for p in itertools.permutations(incorrect_large_index)]
3277 permutations.extend(
3278 [p for p in itertools.permutations(incorrect_small_index)]
3279 )
3280 elif error_name == ErrorIf.IndexUsedTwice:
3281 # Create list with a duplicated index
3282 perm_range = list(range(len(ifm_shape)))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003283 index_choice = rng.choice(range(len(perm_range)))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003284 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
3285 permutations = [p for p in itertools.permutations(perm_range)]
3286
3287 else:
3288 # Get all permutations
3289 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
3290
3291 # Limit to possible permutations from shape dimension or argument setting
3292 limit = min(len(permutations), testGen.args.num_rand_permutations)
3293
3294 # Get random permutation generator that uses all permutations
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003295 random_permutations = rng.permutation(permutations)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003296
3297 # Create list of required amount of permutations
3298 arg_list = [
evacha0198477222024-01-26 12:25:32 +00003299 ("perm{}".format(p), {"perms": random_permutations[p].tolist()})
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003300 for p in range(limit)
3301 ]
evacha0198477222024-01-26 12:25:32 +00003302 # Now add data generator types
3303 arg_list = TosaArgGen._add_data_generators(
3304 testGen,
3305 opName,
evacha019c96eef2024-02-07 11:21:55 +00003306 shapeList,
evacha0198477222024-01-26 12:25:32 +00003307 dtype,
3308 arg_list,
3309 error_name,
3310 )
3311 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003312 return arg_list
3313
3314 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003315 def agSlice(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003316 arg_list = []
3317
3318 ifm_shape = shapeList[0]
3319 rank = len(ifm_shape)
3320
3321 for p in range(testGen.args.num_rand_permutations):
3322 start = []
3323 size = []
3324
3325 valid = True
3326
3327 for i in range(rank):
3328 if ifm_shape[i] > 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003329 start.append(rng.randInt(0, ifm_shape[i]))
3330 size.append(rng.randInt(0, ifm_shape[i] - start[i]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003331
3332 # Invalid slice size?
3333 if size[i] == 0:
3334 valid = False
3335 else:
3336 start.append(0)
3337 size.append(1)
3338
3339 if valid:
3340 # If ERROR_IF test required then incorrect start, size will be returned
3341 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003342 rng, error_name, ifm_shape, start, size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003343 )
evacha017f7d4252024-01-24 12:08:09 +00003344 arg_list.append(("perm{}".format(p), {"start": start, "size": size}))
3345 # Now add data generator types
3346 arg_list = TosaArgGen._add_data_generators(
3347 testGen,
3348 opName,
evacha019c96eef2024-02-07 11:21:55 +00003349 shapeList,
evacha017f7d4252024-01-24 12:08:09 +00003350 dtype,
3351 arg_list,
3352 error_name,
3353 )
3354 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003355 return arg_list
3356
3357 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003358 def agTile(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003359 arg_list = []
3360
3361 ifm_shape = shapeList[0]
3362 rank = len(ifm_shape)
3363
3364 for p in range(testGen.args.num_rand_permutations):
3365
3366 # Pick a few random, but small multiple values
3367 # because otherwise this has a tendency to generate
3368 # enormous tensors
3369 multiples = []
3370 for i in range(rank):
3371 if ifm_shape[i] > 1000:
3372 # Multiple of 1 if ifm_shape dimension is large to reduce
3373 # tensor size
3374 multiples.append(1)
3375 elif max(ifm_shape) > 1000:
3376 multiples.append(2)
3377 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003378 multiples.append(rng.randInt(1, 4))
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003379 arg_list.append(("perm{}".format(p), {"multiples": multiples}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003380
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003381 # Now add data generator types
3382 arg_list = TosaArgGen._add_data_generators(
3383 testGen,
3384 opName,
evacha019c96eef2024-02-07 11:21:55 +00003385 shapeList,
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003386 dtype,
3387 arg_list,
3388 error_name,
3389 )
3390 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003391 return arg_list
3392
3393 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003394 def agResize(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003395 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003396 ifm_shape = shapeList[0]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003397
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003398 def get_aspect_ratio_resize_params():
3399 common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003400 aspect_ratio = rng.choice(common_aspect_ratios)
3401 invert = rng.choice((False, True))
3402 letterbox = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003403
3404 scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
3405 scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
3406 scale_y_d = scale_x_d = 1
3407 offset_x = offset_y = 0
3408
3409 if letterbox:
3410 max_border = scale_y_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003411 border_y = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003412 border_x = 0
3413 else:
3414 # Pillarboxing
3415 border_y = 0
3416 max_border = scale_x_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003417 border_x = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003418
3419 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3420 offset = (offset_y, offset_x)
3421 border = (border_y, border_x)
3422
3423 return scale, offset, border
3424
3425 def get_upscale_downscale_params():
3426 valid_params = False
3427 while not valid_params:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003428 upscale = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003429
3430 # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003431 origin_sampling = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003432
3433 if upscale:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003434 shift = rng.randInt(low=1, high=4)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003435 scale_x_d = scale_y_d = 1
3436 scale_x_n = scale_y_n = (
3437 1 << shift if origin_sampling else 2 << shift
3438 )
3439 border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
3440 offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
3441 else:
3442 scale_x_n = 1
3443 scale_y_n = 1
3444
3445 # Return list of valid scale_*_d values (max value 4) given input dim shape
3446 def get_valid_denom(ifm_dim):
3447 return [x for x in range(1, 5) if ifm_dim % x == 1]
3448
3449 # Generate list of valid downscale values and choose one randomly
3450 valid_scale_y_ds = get_valid_denom(ifm_shape[1])
3451 valid_scale_x_ds = get_valid_denom(ifm_shape[2])
3452
3453 if not valid_scale_y_ds and not valid_scale_x_ds:
3454 # Bad parameters, skip
3455 continue
3456
3457 if not valid_scale_y_ds:
3458 scale_y_d = 1
3459 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003460 scale_y_d = rng.choice(valid_scale_y_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003461
3462 if not valid_scale_x_ds:
3463 scale_x_d = 1
3464 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003465 scale_x_d = rng.choice(valid_scale_x_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003466
3467 border_x = border_y = 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003468 offset_y = rng.randInt(0, 16 * scale_y_n)
3469 offset_x = rng.randInt(0, 16 * scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003470 valid_params = True
3471
3472 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3473 offset = (offset_y, offset_x)
3474 border = (border_y, border_x)
3475 return scale, offset, border
3476
3477 def get_rand_params():
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003478 def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
3479 scale = scale_n / scale_d
3480 if scale > max_scale:
3481 factor = scale / max_scale
3482 new_scale_d = math.ceil(scale_d * factor)
3483 assert scale_n / new_scale_d <= max_scale
3484 scale_d = new_scale_d
3485 return scale_d
3486
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003487 # Scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003488 scale_y_n = rng.randInt(low=1, high=(1 << 11))
3489 scale_x_n = rng.randInt(low=1, high=(1 << 11))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003490
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003491 scale_y_d = rng.randInt(low=1, high=(16 * scale_y_n))
3492 scale_x_d = rng.randInt(low=1, high=(16 * scale_x_n))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003493
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003494 scale_y_d = fix_scale_to_max_scale(
3495 scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3496 )
3497 scale_x_d = fix_scale_to_max_scale(
3498 scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3499 )
3500
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003501 # Offsets and border within the scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003502 offset_y = rng.randInt(low=-scale_y_n, high=(16 * scale_y_n))
3503 offset_x = rng.randInt(low=-scale_x_n, high=(16 * scale_x_n))
3504 border_y = rng.randInt(low=(-16 * scale_y_n), high=scale_y_n)
3505 border_x = rng.randInt(low=(-16 * scale_x_n), high=scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003506
3507 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3508 offset = (offset_y, offset_x)
3509 border = (border_y, border_x)
3510 return scale, offset, border
3511
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003512 def get_level_8k_params():
3513 # Create 64x scale - 64/1 to 2048/32
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003514 scale_d = rng.randInt(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003515 low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
3516 )
3517 scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
3518 # Create half to fifth scaling
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003519 scale_d_alt = rng.randInt(low=2, high=6)
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003520 scale_n_alt = 1
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003521 switch = rng.choice((False, True))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003522 if switch:
3523 scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
3524 else:
3525 scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)
3526
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003527 offset_y = rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
3528 offset_x = rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003529 offset = (offset_y, offset_x)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003530 border_y = rng.choice((-16 * scale[0], 0, scale[0] - 1))
3531 border_x = rng.choice((-16 * scale[2], 0, scale[2] - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003532 border = (border_y, border_x)
3533 return scale, offset, border
3534
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003535 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003536 # Exclude illegal {mode, type} configurations. Pick legal output types
3537 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
3538 outputDTypeList = [DType.INT8]
3539 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
3540 outputDTypeList = [DType.INT16]
3541 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
3542 outputDTypeList = [DType.INT32]
3543 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
3544 outputDTypeList = [DType.INT48]
James Ward8b390432022-08-12 20:48:56 +01003545 elif dtype == DType.FP16:
3546 outputDTypeList = [DType.FP16]
James Ward24dbc422022-10-19 12:20:31 +01003547 elif dtype == DType.BF16:
3548 outputDTypeList = [DType.BF16]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003549 elif dtype == DType.FP32:
3550 outputDTypeList = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00003551 elif dtype == DType.FP8E4M3:
3552 outputDTypeList = [DType.FP8E4M3]
3553 elif dtype == DType.FP8E5M2:
3554 outputDTypeList = [DType.FP8E5M2]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003555 elif error_name == ErrorIf.WrongInputType:
3556 # If an incorrect input type is used then we set a 'correct'
3557 # output type to avoid other errors
3558 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
3559 else:
3560 continue
3561
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003562 arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"
3563
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003564 for outputDType in outputDTypeList:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003565 perm = 0
3566 while perm < testGen.args.num_rand_permutations:
3567 # Random choice of type of params we are testing
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003568 if not testGen.args.level8k:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003569 _rnd_param_fn = rng.choice(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003570 (
3571 get_rand_params,
3572 get_upscale_downscale_params,
3573 get_aspect_ratio_resize_params,
3574 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003575 )
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003576 scale, offset, border = _rnd_param_fn()
3577 else:
3578 scale, offset, border = get_level_8k_params()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003579
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003580 # Expand params for bounds-checking
3581 (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
3582 (offset_y, offset_x) = offset
3583 (border_y, border_x) = border
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003584
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003585 # Make sure output dimensions OH and OW are integers
3586 partial_output_y = (
3587 (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
3588 )
3589 partial_output_x = (
3590 (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
3591 )
3592 if error_name == ErrorIf.ResizeOutputShapeNonInteger:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003593 # Look for non-integer test
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003594 if (
3595 partial_output_y % scale_y_d == 0
3596 and partial_output_x % scale_x_d == 0
3597 ):
3598 # Skip this test as it doesn't produce NonInteger output
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003599 if perm > 0:
3600 perm += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003601 continue
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003602 else:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003603 # Alter the scaling factors to make the output integer
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003604 while partial_output_y % scale_y_d != 0:
3605 scale_y_d -= 1
3606 while partial_output_x % scale_x_d != 0:
3607 scale_x_d -= 1
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003608 # Make sure we are still within max scaling
3609 if (
3610 scale_y_n / scale_y_d
3611 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
3612 scale_x_n / scale_x_d
3613 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
3614 # Skip the test as it is using too large a scaling factor
3615 if perm > 0:
3616 perm += 1
3617 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003618
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003619 output_y = partial_output_y // scale_y_d + 1
3620 output_x = partial_output_x // scale_x_d + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003621
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003622 if (
3623 output_y >= testGen.args.max_resize_output_dim
3624 or output_x >= testGen.args.max_resize_output_dim
3625 ) and error_name is None:
3626 # Skip positive test if output dim will be too high
3627 # Avoid high test latency and OOM issues
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003628 if not testGen.args.level8k or perm > 0:
3629 perm += 1
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003630 continue
3631
3632 if (
3633 output_y <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003634 or output_y >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003635 or output_x <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003636 or output_x >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003637 ):
3638 # Output dimensions out of scope
3639 if error_name is not None and perm > 0:
3640 # As long as we have one ERROR_IF test, don't worry
3641 # about creating all the other permutations
3642 perm += 1
3643 continue
3644
3645 if error_name == ErrorIf.ResizeOutputShapeMismatch and (
3646 (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003647 output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003648 and output_y - scale_y_d < 1
3649 )
3650 or (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003651 output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003652 and output_x - scale_x_d < 1
3653 )
3654 ):
3655 # Can't create a negative test with these params as it
3656 # will create invalid output size
3657 if perm > 0:
3658 perm += 1
3659 continue
3660
3661 scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
3662 offset = [offset_y, offset_x]
3663 border = [border_y, border_x]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003664
3665 # Common for all data types
3666 if error_name is not None:
3667 (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003668 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003669 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003670 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003671 outputDTypeNew,
3672 ) = TosaErrorIfArgGen.eiResizeErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003673 rng,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003674 error_name,
3675 mode,
3676 dtype,
3677 shapeList,
3678 outputDType,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003679 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003680 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003681 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003682 )
3683 else:
3684 outputDTypeNew = outputDType
3685
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003686 arg_to_append = (
3687 arg_str.format(
3688 "N" if mode == ResizeMode.NEAREST else "B",
3689 testGen.typeStr(outputDTypeNew),
3690 scale[0],
3691 scale[1],
3692 scale[2],
3693 scale[3],
3694 offset[0],
3695 offset[1],
3696 border[0],
3697 border[1],
3698 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003699 {
3700 "mode": mode,
3701 "scale": scale,
3702 "offset": offset,
3703 "border": border,
3704 "output_dtype": outputDTypeNew,
3705 },
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003706 )
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003707 if arg_to_append in arg_list:
3708 # Skip already generated test params
3709 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003710
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003711 # Valid permutation
3712 perm += 1
3713 arg_list.append(arg_to_append)
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003714
3715 # Now add data generator types
3716 arg_list = TosaArgGen._add_data_generators(
3717 testGen,
3718 opName,
evacha019c96eef2024-02-07 11:21:55 +00003719 shapeList,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003720 dtype,
3721 arg_list,
3722 error_name,
3723 )
3724 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003725 return arg_list
3726
3727 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003728 def agTable(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003729 arg_list = []
3730
3731 if dtype == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003732 table = np.int32(rng.integers(low=-128, high=128, size=[256])).tolist()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003733 else: # INT16
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003734 table = np.int32(rng.integers(low=-32768, high=32768, size=[513])).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003735 # Make sure all slopes are within REQUIRE min/max 16-bit int
3736 for idx in range(len(table) - 1):
3737 slope = table[idx + 1] - table[idx]
3738 # Alter the next table entry to force the slope to be ok
3739 if slope > 32767:
3740 table[idx + 1] -= slope - 32767
3741 if slope < -32768:
3742 table[idx + 1] -= slope + 32768
3743 slope = table[idx + 1] - table[idx]
3744 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003745 arg_list.append(
3746 (
3747 "",
Jeremy Johnson587cc842024-02-08 11:45:44 +00003748 {"table": table},
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003749 )
3750 )
Jeremy Johnson587cc842024-02-08 11:45:44 +00003751 # Now add data generator types
3752 arg_list = TosaArgGen._add_data_generators(
3753 testGen,
3754 opName,
evacha019c96eef2024-02-07 11:21:55 +00003755 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003756 dtype,
3757 arg_list,
3758 error_name,
3759 )
3760 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003761 return arg_list
3762
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003763 def agCondIf(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003764 # CondIf generates the condition values here.
3765 # Convert to tensors in the build function, along with the
3766 # then and else blocks
3767 arg_list = []
3768
3769 for c in [False, True]:
Jeremy Johnson587cc842024-02-08 11:45:44 +00003770 arg_list.append(("cond{}".format(int(c)), {"condition": c}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003771
Jeremy Johnson587cc842024-02-08 11:45:44 +00003772 # Now add data generator types
3773 arg_list = TosaArgGen._add_data_generators(
3774 testGen,
3775 opName,
evacha019c96eef2024-02-07 11:21:55 +00003776 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003777 dtype,
3778 arg_list,
3779 error_name,
3780 )
3781 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003782 return arg_list
3783
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003784 def agWhileLoop(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003785 # While loop: 0 iterations, 1, more than 1
3786 arg_list = []
3787
Jeremy Johnson587cc842024-02-08 11:45:44 +00003788 for iterations in [0, 1, 4]:
3789 arg_list.append(("iter{}".format(iterations), {"iterations": iterations}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003790
Jeremy Johnson587cc842024-02-08 11:45:44 +00003791 # Now add data generator types
3792 arg_list = TosaArgGen._add_data_generators(
3793 testGen,
3794 opName,
evacha019c96eef2024-02-07 11:21:55 +00003795 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003796 dtype,
3797 arg_list,
3798 error_name,
3799 )
3800 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003801 return arg_list