blob: 3a859610415e3b053873a5ff6875520b685bf553 [file] [log] [blame]
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
Jeremy Johnsonaf090182024-02-13 18:25:39 +00004import logging
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01005import math
6
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
16# DTypeNames, DType, Op and ResizeMode are convenience variables to the
17# flatc-generated types that should be enums, but aren't
18
Jeremy Johnsonaf090182024-02-13 18:25:39 +000019logging.basicConfig()
20logger = logging.getLogger("tosa_verif_build_tests")
21
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010022
class TosaQuantGen:
    """Helper functions for generating random quantization info.

    Operators request quantization info via a 'qgen' entry in their
    operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(rng, zeropoint, dtype, error_name=None):
        # Return a zero point suitable for dtype: clamp a user-supplied
        # value into range, otherwise pick a random in-range one.
        if dtype == DType.INT8:
            if zeropoint is not None:
                return min(127, max(-128, zeropoint))
            return rng.randInt(-128, 128)
        elif dtype == DType.UINT8:
            if zeropoint is not None:
                return min(255, max(0, zeropoint))
            return rng.randInt(0, 256)
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            # Deliberately invalid: any non-zero value triggers the ERROR_IF
            zero_point = rng.randInt(-128, 128)
            return zero_point if zero_point != 0 else 1
        return 0

    @staticmethod
    def qgUnary(rng, zeropoint, op, dtype, error_name=None):
        # Returns [input_zp, output_zp]; the faulty side (if any) has the
        # error_name forwarded so it becomes non-zero.
        input_err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        output_err = error_name if error_name == ErrorIf.OutputZeroPointNotZero else None
        return [
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, input_err),
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, output_err),
        ]

    @staticmethod
    def qgConv(rng, zeropoint, op, dtype_or_dtypeList, error_name=None):
        # Returns [input_zp, weight_zp] for convolution-like operators.
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        input_err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        weight_err = error_name if error_name == ErrorIf.WeightZeroPointNotZero else None
        return [
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0], input_err),
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1], weight_err),
        ]

    @staticmethod
    def qgMatmul(rng, zeropoint, op, dtype, error_name=None):
        # Returns [a_zp, b_zp]; both inputs get the error forwarded for
        # InputZeroPointNotZero tests.
        err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        return [
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, err),
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, err),
        ]

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a float scale into a fixed-point (multiplier, shift) pair.

        Derived from computeMultiplierAndShiftTosaScale32: scale32 selects a
        32-bit (31 fractional bits) or 16-bit (15 fractional bits) encoding.
        """
        scaleBits = 31 if scale32 else 15

        m, shift = math.frexp(scaleFp)
        if scaleFp < 0.0:
            # frexp gives a negative mantissa for negative input; make it
            # positive so the rounding below behaves
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            # Rounding pushed the mantissa up to exactly 1.0 - renormalize
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        logger.debug(
            f"computeMultiplierAndShift: scalefp={scaleFp} scaleBits={scaleBits} m={m} mult={multiplier} shift={shift}"
        )

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert 2 <= shift <= 62

        return multiplier, shift
156
157
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const
    tensor data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, rng, op, rank, error_name=None):
        # All operands share one randomly chosen shape.
        pl, const = op["operands"]
        shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for idx in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and idx != 1:
                    shape = testGen.makeShape(rng, rank + rng.choice([1, 2, 3]))
                elif idx != 1:
                    shape = testGen.makeShape(rng, rank + rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, rng, op, rank, error_name=None):
        # Single NHWC shape shared by every operand.
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rng, rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        return [shape.copy() for _ in range(pl + const)]

    @staticmethod
    def tgGather(testGen, rng, opName, rank, error_name=None):
        # GATHER operands: values (N, K, C) and indices (N, W).
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_shape = testGen.makeShape(rng, rank)
        values_shape = testGen.constrictBatchSize(values_shape)

        batch = values_shape[0]
        width = testGen.makeDimension(rng)

        return [values_shape, [batch, width]]

    @staticmethod
    def tgScatter(testGen, rng, opName, rank, error_name=None):
        # SCATTER operands: values_in (N, K, C), indices (N, W), input (N, W, C).
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rng, rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = rng.randInt(W_min, W_max) if W_min < W_max else W_min

        return [values_in_shape, [N, W], [N, W, C]]

    @staticmethod
    def _get_broadcast_shapes(testGen, rng, num_shapes, rank, error_name=None):
        # Produce num_shapes shapes where one input is broadcast against the rest.
        base_shape = testGen.makeShape(rng, rank)
        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = rng.randInt(0 if error_name is None else 1, num_shapes)
        fuzz_idx = rng.randInt(0, rank)

        for idx in range(num_shapes):
            shape_bcast = base_shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if idx == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(rng, extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]
        return TosaTensorGen._get_broadcast_shapes(
            testGen, rng, pl + const, rank, error_name
        )

    @staticmethod
    def tgMul(testGen, rng, op, rank, error_name=None):
        # Get broadcast shapes for the first 2 inputs as the 3rd is shift
        shapes = TosaTensorGen._get_broadcast_shapes(
            testGen, rng, 2, rank, error_name
        )
        # Add a single dimension tensor for shift
        shapes.append([1])
        return shapes

    @staticmethod
    def tgConv2D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        kernel_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension(rng)

        # The filter dimensions are OHWI
        filter_shape = np.asarray(
            [ofm_depth, kernel_hw[0], kernel_hw[1], ifm_shape[3]]
        )

        # The bias is OC, or 1 when the operator allows a broadcastable bias
        if op.get("broadcastable_bias", False) and rng.choice([True, False]):
            ofm_depth = 1
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        kernel_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension(rng)

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, kernel_dhw[0], kernel_dhw[1], kernel_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        kernel_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension(rng)

        # The filter dimensions are OHWI
        filter_shape = np.asarray(
            [ofm_depth, kernel_hw[0], kernel_hw[1], ifm_shape[3]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        kernel_hw = op["filter"]

        # Generate a random channel multiplier M, but don't let it get too big
        # because the output depth is M * C
        filter_m = (
            testGen.makeDimension(rng) % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray(
            [kernel_hw[0], kernel_hw[1], ifm_shape[3], filter_m]
        )

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rng, rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            bump = rng.choice([(inc_h, 0), (0, inc_w), (inc_h, inc_w)])
            ifm_shape[1] += bump[0]
            ifm_shape[2] += bump[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rng, rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            bump = rng.choice([(inc_h, 0), (0, inc_w), (inc_h, inc_w)])
            ifm_shape[1] += bump[0]
            ifm_shape[2] += bump[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])
        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]
        shape = testGen.makeShape(rng, rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = rng.randInt(0, 4)

        shape_list = []
        for idx in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and idx != 0:
                # Build a deliberately wrong-rank shape for this input
                remove = rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(rng, shapeList, axis, error_name=None):
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shp in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shp[(axis + 1) % len(shp)] += rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        remaining_length = shape[axis]
        for idx in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if idx == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
649
650
class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each tensor in each test."""

    def __init__(self):
        pass

    class TVGInfo:
        """Enhanced tensor values information including data gen dict."""

        def __init__(self, tensorList, dataGenDict):
            self.tensorList = tensorList
            self.dataGenDict = dataGenDict

    # Default high (largest finite) values for random numbers, per FP type
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
        DType.FP8E4M3: 448,
        DType.FP8E5M2: 57344,
    }

    # Default lowest normal (smallest positive normal) values, per FP type
    TVG_FLOAT_LOW_VALUE = {
        DType.FP32: np.exp2(-126),
        DType.FP16: np.exp2(-14),
        DType.BF16: np.exp2(-126),
        DType.FP8E4M3: np.exp2(-9),
        DType.FP8E5M2: np.exp2(-16),
    }
681
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100682 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100683 def _get_data_range(rng, dtype, highValueLookup, lowValueLookup=None):
Jeremy Johnson30476252023-11-20 16:15:30 +0000684 # Return a tuple of (low,high) data range values for the given data
685 # type using a combination of per operator table limits, data limits
686 # and user supplied ranges for FP numbers
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000687 if dtype in highValueLookup:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100688 type_range = rng.dTypeRange(dtype, high_inclusive=True)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000689 high_val = highValueLookup[dtype]
Jeremy Johnson30476252023-11-20 16:15:30 +0000690 if lowValueLookup is not None and dtype in lowValueLookup:
691 low_val = lowValueLookup[dtype]
692 else:
693 low_val = -high_val
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000694 # Set the values to something that won't produce infinity whilst
Jeremy Johnson30476252023-11-20 16:15:30 +0000695 # respecting the default ranges if more/less than the low/high
696 # values
697 data_range = (
698 max(low_val, type_range[0]),
699 min(high_val, type_range[1]),
700 )
701 if data_range[0] > data_range[1]:
702 # Invalid data range from low to high created due to user
703 # constraints revert to using internal ranges as they are
704 # known to work
Jeremy Johnsonaf090182024-02-13 18:25:39 +0000705 logger.info(
706 f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
707 )
Jeremy Johnson30476252023-11-20 16:15:30 +0000708 data_range = (low_val, high_val)
709 return data_range
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000710 return None
711
    @staticmethod
    def tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Default tensor value generator used by most operators.

        Creates the placeholder/constant tensors for the test. When the
        operator supports the compliance data generator library, this builds
        JSON meta-data describing how each tensor's data is produced
        (optionally deferring actual data creation - "lazy" mode); otherwise
        it falls back to generating the data internally via the rng.

        Args:
            testGen: test generator - supplies TOSA_OP_LIST, serializer
                (ser), data gen library (dgl) and command line args
            rng: per-test random number generator
            opName: key into testGen.TOSA_OP_LIST
            dtypeList: data type per input tensor
            shapeList: shape per input tensor
            argsDict: per-test arguments - may contain "p_count"/"c_count",
                "data_range", "data_range_list", "fixed_data", "dg_type" etc.
            error_name: ERROR_IF test name, or None for a valid test

        Returns:
            TVGInfo(tensor list, data gen meta-data) - the meta-data is None
            when the internal fallback generation path was used
        """
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
        if "p_count" in argsDict:
            # Override for operators like CONCAT
            pCount = argsDict["p_count"]
            cCount = argsDict["c_count"]
        assert pCount + cCount == len(
            shapeList
        ), "Placeholders & Constant tensors must match shapes list"

        tens_ser_list = []

        if (
            error_name is not None
            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
            or "data_gen" not in testGen.TOSA_OP_LIST[opName]
        ):
            # Fall back to internal data gen when dealing with unsupported types or ops
            data_range = argsDict["data_range"] if "data_range" in argsDict else None
            for idx, info in enumerate(zip(shapeList, dtypeList)):
                roundMode = False
                shape, dtype = info
                if "data_range_list" in argsDict:
                    # Per-tensor range (and optional rounding) overrides the
                    # single shared data_range
                    data_range = argsDict["data_range_list"][idx]["range"]
                    roundMode = (
                        "round" in argsDict["data_range_list"][idx]
                        and argsDict["data_range_list"][idx]["round"] is True
                    )
                if data_range is not None and dtype not in (
                    DType.FP16,
                    DType.FP32,
                    DType.BF16,
                    DType.FP8E4M3,
                    DType.FP8E5M2,
                ):
                    # Change from inclusive to exclusive range
                    data_range = (data_range[0], data_range[1] + 1)

                # Ignore lazy data gen option and create data array using any range limits
                if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                    # Pre-computed values (e.g. shape operands) - convert to
                    # the matching numpy type
                    if dtype == DType.SHAPE:
                        arr = np.int64(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT8:
                        arr = np.int8(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT16:
                        arr = np.int16(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT32:
                        arr = np.int32(argsDict["fixed_data"][idx])
                    else:
                        assert False, "Unsupported fixed_data type"
                else:
                    arr = rng.randTensor(shape, dtype, data_range)
                if roundMode:
                    arr = np.round(arr)
                if idx < pCount:
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
        dg_type = argsDict["dg_type"]
        tens_data = {
            "version": "0.1",
            "tensors": {},
        }
        dg_tens_meta = tens_data["tensors"]
        for idx, shape in enumerate(shapeList):

            tens_meta = {}
            # Tensors with pre-computed values always use the FIXED_DATA
            # generator, regardless of the operator's chosen dg_type
            if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                tens_meta["generator"] = gtu.DataGenType(
                    gtu.DataGenType.FIXED_DATA
                ).name
            else:
                tens_meta["generator"] = gtu.DataGenType(dg_type).name

            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()

            if idx < pCount:
                tens_meta["input_type"] = "VARIABLE"
            else:
                tens_meta["input_type"] = "CONSTANT"

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                if (
                    tens_meta["generator"]
                    == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name
                ):
                    info["data"] = [int(i) for i in argsDict["fixed_data"][idx]]
                    tens_meta["fixed_data_info"] = info
                else:
                    info["rng_seed"] = rng.seed

                    # Per-tensor range list takes precedence over the shared
                    # data_range; fall back to the type's full range
                    data_range = None
                    if "data_range_list" in argsDict:
                        data_range = argsDict["data_range_list"][idx]["range"]
                        if "round" in argsDict["data_range_list"][idx]:
                            info["round"] = argsDict["data_range_list"][idx]["round"]
                    elif "data_range" in argsDict:
                        data_range = argsDict["data_range"]

                    if data_range is None:
                        data_range = rng.dTypeRange(dtypeList[idx], high_inclusive=True)
                    info["range"] = [str(v) for v in data_range]
                    tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = int(argsDict["ks"])
                if "acc_type" in argsDict:
                    # Convert type number into JSON name
                    info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
                        "json"
                    ]
                if "kernel" in argsDict:
                    info["kernel"] = [int(k) for k in argsDict["kernel"]]
                if "axis" in argsDict:
                    info["axis"] = int(argsDict["axis"])
                tens_meta["dot_product_info"] = info
            elif dg_type == gtu.DataGenType.FULL_RANGE:
                info = {}
                info["start_val"] = int(
                    rng.randInt(0, gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["fullset"])
                )
                tens_meta["full_range_info"] = info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"

            # Using the finished generate config meta data - generate the data if
            # needed and assign a tensor name from the serializer

            # Need to generate data when not lazy or for the bias tensor as we need
            # to work out if the bias data is non-zero for compliance
            if not testGen.args.lazy_data_gen or (
                idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
            ):
                # Give this tensor a temporary name until we get one from the serializer
                temp_name = f"placeholder_{idx}"
                dg_tens_meta[temp_name] = tens_meta
                # Create data now using the temporary name to access meta details
                data = testGen.dgl.get_tensor_data(temp_name, tens_data)
                if tens_meta["data_type"] == "SHAPE":
                    # Tensor type SHAPE and Numpy file type must be the same
                    data = np.int64(data)
                # Remove the item as we will give it the correct name later
                del dg_tens_meta[temp_name]

                if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
                    # The KS value used by compliance verification is altered when the
                    # bias data is non-zero
                    # NOTE(review): assumes the bias tensor is 1D so builtin
                    # max() works on abs(data) - confirm for all dot product ops
                    if max(abs(data)) > 0.0:
                        argsDict["ksb"] = argsDict["ks"] + 1

            if testGen.args.lazy_data_gen:
                # Lazy mode - data is created later by the data generator
                # library, so serialize the tensors without values
                data = None

            if tens_meta["input_type"] == "VARIABLE":
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
            else:
                tens = testGen.ser.addConst(shape, dtypeList[idx], data)

            tens_ser_list.append(tens)
            # Add the meta data to the list using the serializer tensor name
            dg_tens_meta[tens.name] = tens_meta

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
889
890 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100891 def tvgNegate(
892 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
893 ):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100894 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000895 # Integer test
896 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100897 pCount, cCount = op["operands"]
898 assert (
899 pCount == 1 and cCount == 0
900 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100901 # Must create tensors with values within accumulator (int32) negatable
902 # range
903 max_val = (1 << 31) - 1
904 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100905 arr = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100906 rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100907 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000908 tens_ser_list = []
909 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100910 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
911 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000912 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100913 else:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000914 # ERROR_IF or floating point test
915 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100916 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100917 )
918
Jeremy Johnson30476252023-11-20 16:15:30 +0000919 # Set the ADD/SUB data range to half the largest value to avoid infinities
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000920 TVG_FLOAT_HIGH_VALUE_ADDSUB = {
921 DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
922 DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
923 DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
924 }
925
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100926 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100927 def tvgAddSub(
928 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
929 ):
Won Jeon74342e52024-01-09 00:34:40 +0000930 if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000931 # Make sure the integer operation does not cause value saturation - where
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100932 # the number wraps due to limited number of bits to store the answer
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000933 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100934 pCount, cCount = op["operands"]
935 assert (
936 pCount == 2 and cCount == 0
937 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000938 tens_ser_list = []
Won Jeon74342e52024-01-09 00:34:40 +0000939 add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
Jeremy Johnson32bf9012024-03-20 16:32:23 +0000940 data_range = None # Use default
941 if op["op"] in (Op.ADD_SHAPE, Op.SUB_SHAPE):
942 data_range = testGen.args.tensor_shape_range
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100943 a_arr = rng.randTensor(shapeList[0], dtypeList[0], data_range)
944 b_arr = rng.randTensor(shapeList[1], dtypeList[1], data_range)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100945 if add:
946 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
947 else:
948 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
949
950 # Work out the saturation limits
951 max_i32 = (1 << 31) - 1
952 min_i32 = -(1 << 31)
953 max_arr = np.full(shapeList[1], max_i32)
954 min_arr = np.full(shapeList[1], min_i32)
955
956 # Find how much values exceed the maximum/minimums
957 sat_max_arr = np.maximum(res_arr - max_arr, 0)
958 sat_min_arr = np.minimum(res_arr - min_arr, 0)
959
960 if not add:
961 # Swap saturation values and negate values as we need to perform opposite operations
962 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
963
964 # Create new array of unsaturated values by clipping values as needed
965 b_unsat_arr = b_arr
966 if (sat_max_arr != 0).any():
967 # Clip values that cause saturation
968 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
969 # Reduce axes in unsaturated tensor to match original tensor
970 for axis, dim in enumerate(b_arr.shape):
971 if dim != b_unsat_arr.shape[axis]:
972 assert (
973 dim == 1
974 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
975 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
976
977 if (sat_min_arr != 0).any():
978 # Clip values that cause saturation
979 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
980 # Reduce axes in unsaturated tensor to match original tensor
981 for axis, dim in enumerate(b_arr.shape):
982 if dim != b_unsat_arr.shape[axis]:
983 assert (
984 dim == 1
985 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
986 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
987
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000988 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100989 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
990 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000991 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100992 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
993 )
994
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000995 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100996 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000997 # ERROR_IF or floating point test
998 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100999 rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001000 )
1001 if data_range:
1002 argsDict["data_range"] = data_range
1003
1004 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001005 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001006 )
1007
    @staticmethod
    def tvgCondIfWhileLoop(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Create input values for COND_IF / WHILE_LOOP tests.

        Integer inputs are restricted so the ops used inside the control flow
        construct behave; other types use the default lazy generator.
        """
        if dtypeList[0] in (
            DType.INT32,
            DType.INT16,
            DType.INT8,
        ):
            # Limit input tensors with cond_if_binary or while_loop to stop
            # saturation of add/sub ops with int32 and keep all logical shift
            # values between 0 to 31 for int16 or int8
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            pRemain = pCount
            tens_ser_list = []
            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[0] == DType.INT32:
                    # INT16-ranged values cannot saturate an int32 add/sub
                    arr = rng.randTensor(shapeList[idx], DType.INT16)
                else:
                    arr = np.int32(rng.integers(low=0, high=32, size=shapeList[idx]))
                if pRemain > 0:
                    # First pCount tensors are variable (placeholder) inputs
                    tens_ser_list.append(
                        testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
                    )
                    pRemain -= 1
                else:
                    tens_ser_list.append(
                        testGen.ser.addConst(shape, dtypeList[idx], arr)
                    )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )
1044
1045 @staticmethod
1046 def tvgArithmeticRightShift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001047 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001048 ):
Jeremy Johnson587cc842024-02-08 11:45:44 +00001049 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001050 pCount, cCount = op["operands"]
1051 # Force value of operand[1] to be within [0, num_bits]
1052 assert (
1053 pCount == 2 and cCount == 0
1054 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
1055
Jeremy Johnson587cc842024-02-08 11:45:44 +00001056 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001057 for idx, shape in enumerate(shapeList[:]):
1058 if idx == 1:
1059 if dtypeList[idx] == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001060 arr = np.int32(rng.integers(low=0, high=8, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001061 elif dtypeList[idx] == DType.INT16:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001062 arr = np.int32(rng.integers(low=0, high=16, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001063 elif dtypeList[idx] == DType.INT32:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001064 arr = np.int32(rng.integers(low=0, high=32, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001065 elif error_name == ErrorIf.WrongInputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001066 arr = np.int32(rng.integers(low=0, high=8, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001067 else:
1068 raise Exception("OpArithmeticRightShift: invalid input dtype")
1069 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001070 arr = rng.randTensor(shape, dtypeList[idx])
Jeremy Johnson587cc842024-02-08 11:45:44 +00001071 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001072
Jeremy Johnson587cc842024-02-08 11:45:44 +00001073 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001074
1075 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001076 def tvgReshape(
1077 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1078 ):
Won Jeon64e4bfe2024-01-18 06:31:55 +00001079 dtypeList[1] = DType.SHAPE
1080 shapeList[1] = [len(argsDict["new_shape"])]
1081 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1082 argsDict["fixed_data"] = [None, argsDict["new_shape"]]
1083
1084 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001085 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Won Jeon64e4bfe2024-01-18 06:31:55 +00001086 )
1087
1088 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001089 def tvgRescale(
1090 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1091 ):
Tai Ly6e1e2bc2024-03-01 20:59:32 +00001092 scale32 = argsDict["scale"]
1093 multiplier_arr = argsDict["multiplier"]
1094 shift_arr = argsDict["shift"]
1095
1096 if scale32:
1097 dtypeList[1] = DType.INT32
1098 else:
1099 dtypeList[1] = DType.INT16
1100 shapeList[1] = [len(multiplier_arr)]
1101 dtypeList[2] = DType.INT8
1102 shapeList[2] = [len(shift_arr)]
1103 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1104 argsDict["fixed_data"] = [None, multiplier_arr, shift_arr]
1105
1106 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001107 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Tai Ly6e1e2bc2024-03-01 20:59:32 +00001108 )
1109
1110 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001111 def tvgPad(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Tai Lye095da72024-01-25 22:00:18 +00001112 # argsDict["pad"] is 2D array, need to flatten it to get list of values
1113 pad_values = argsDict["pad"].flatten()
1114 dtypeList[1] = DType.SHAPE
1115 shapeList[1] = [len(pad_values)]
1116 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1117 argsDict["fixed_data"] = [None, pad_values]
1118
1119 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001120 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Tai Lye095da72024-01-25 22:00:18 +00001121 )
1122
1123 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001124 def tvgSlice(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
TatWai Chongf15bad82024-01-31 21:33:27 -08001125 dtypeList[1] = DType.SHAPE
1126 shapeList[1] = [len(argsDict["start"])]
1127 dtypeList[2] = DType.SHAPE
1128 shapeList[2] = [len(argsDict["size"])]
1129 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1130 argsDict["fixed_data"] = [None, argsDict["start"], argsDict["size"]]
1131
1132 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001133 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
TatWai Chongf15bad82024-01-31 21:33:27 -08001134 )
1135
1136 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001137 def tvgTile(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Won Jeon64e4bfe2024-01-18 06:31:55 +00001138 dtypeList[1] = DType.SHAPE
1139 shapeList[1] = [len(argsDict["multiples"])]
1140 argsDict["fixed_data"] = [None, argsDict["multiples"]]
1141
1142 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001143 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Won Jeon64e4bfe2024-01-18 06:31:55 +00001144 )
1145
1146 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001147 def tvgSelect(
1148 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1149 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001150 # Set datatype of condition tensor to boolean
1151 dtypeList[0] = DType.BOOL
1152
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00001153 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001154 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001155 )
1156
1157 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001158 def tvgIntDiv(
1159 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1160 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001161 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001162 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001163 pCount, cCount = op["operands"]
1164 assert (
1165 pCount == 2 and cCount == 0
1166 ), "Op.INTDIV must have 2 placeholders, 0 consts"
1167
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001168 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001169
1170 # Two invalid cases for Op.INTDIV:
1171 # 1. divisor == 0
1172 # 2. dividend == -(1<<31) and divisor == -1
1173 while True:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001174 dividend_arr = rng.randTensor(shapeList[0], dtypeList[0])
1175 divisor_arr = rng.randTensor(shapeList[1], dtypeList[1])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001176
1177 if (divisor_arr == 0).any():
1178 continue
1179
1180 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
1181 continue
1182
1183 break
1184
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001185 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001186 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1187 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001188 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001189 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1190 )
1191
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001192 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001193 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001194 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001195 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001196 )
1197
Jeremy Johnson30476252023-11-20 16:15:30 +00001198 # Set the MUL data range to the square root of the largest value
1199 # to avoid infinities
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001200 TVG_FLOAT_HIGH_VALUE_MUL = {
1201 DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1202 DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1203 DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1204 }
1205
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001206 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001207 def tvgMul(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001208 if error_name is not None or dtypeList[0] in (
1209 DType.FP16,
1210 DType.BF16,
1211 DType.FP32,
1212 ):
1213 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001214 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001215 rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001216 )
1217 if data_range:
1218 argsDict["data_range"] = data_range
1219
Jeremy Johnson0a042992024-02-28 13:20:05 +00001220 if dtypeList[0] != DType.SHAPE:
1221 # Need to supply shift tensor for MUL (not needed for MUL_SHAPE)
1222 dtypeList[2] = DType.INT8
1223 shapeList[2] = [1]
1224 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1225 argsDict["fixed_data"] = [None, None, [argsDict["shift"]]]
1226
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001227 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001228 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001229 )
1230 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001231 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001232 pCount, cCount = op["operands"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001233
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001234 tens_ser_list = []
1235
1236 # Make sure multiply result in int32 range
Won Jeon74342e52024-01-09 00:34:40 +00001237 if dtypeList[0] == DType.SHAPE:
1238 shift = 0
1239 else:
1240 shift = argsDict["shift"]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001241 if dtypeList[0] == DType.INT8:
1242 num_bits = 8
1243 elif dtypeList[0] == DType.INT16:
1244 num_bits = 16
Won Jeon74342e52024-01-09 00:34:40 +00001245 elif dtypeList[0] in (DType.INT32, DType.SHAPE):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001246 num_bits = 32
1247 elif error_name == ErrorIf.WrongInputType:
1248 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001249 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001250 raise Exception(
1251 f"OpMul: invalid input dtype {gtu.DTYPE_ATTRIBUTES[dtypeList[0]]['str']}"
1252 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001253
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001254 for idx, shape in enumerate(shapeList[:]):
Won Jeon74342e52024-01-09 00:34:40 +00001255 if dtypeList[idx] == DType.SHAPE:
1256 low = testGen.args.tensor_shape_range[0]
1257 high = testGen.args.tensor_shape_range[1]
1258 else:
1259 low = -(2 ** (num_bits - 1))
1260 high = (2 ** (num_bits - 1)) - 1
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001261
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001262 a_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[0]))
1263 b_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[1]))
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001264
1265 i = 0
1266 while True:
1267
1268 a_arr_64 = a_arr.astype(np.int64)
1269 b_arr_64 = b_arr.astype(np.int64)
1270
1271 if shift > 0:
1272 rounding = 1 << (shift - 1)
1273 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001274 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001275 result_arr = a_arr_64 * b_arr_64
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001276
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001277 if (result_arr > -(2**31)).all() and (
1278 result_arr <= ((2**31) - 1)
1279 ).all():
1280 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001281
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001282 i = i + 1
1283 a_arr = a_arr // 2
1284 b_arr = b_arr // 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001285
Won Jeon74342e52024-01-09 00:34:40 +00001286 if dtypeList[0] == DType.SHAPE:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001287 # MUL_SHAPE with 2 inputs
Won Jeon74342e52024-01-09 00:34:40 +00001288 tens_ser_list.append(
1289 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
1290 )
1291 tens_ser_list.append(
1292 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
1293 )
1294 else:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001295 # MUL with 3 inputs (3rd is shift)
Won Jeon74342e52024-01-09 00:34:40 +00001296 tens_ser_list.append(
1297 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1298 )
1299 tens_ser_list.append(
1300 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1301 )
Jeremy Johnson0a042992024-02-28 13:20:05 +00001302 tens_ser_list.append(
1303 testGen.ser.addPlaceholder([1], DType.INT8, np.int8([shift]))
1304 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001305
1306 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001307
1308 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001309 def tvgConcat(
1310 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1311 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001312 count = len(shapeList) - testGen.args.num_const_inputs_concat
1313 if count < 1:
1314 count = 1
1315 if testGen.args.num_const_inputs_concat == 0:
1316 count = len(shapeList)
1317
Won Jeon74342e52024-01-09 00:34:40 +00001318 op = testGen.TOSA_OP_LIST[opName]
1319 if op["op"] == Op.CONCAT_SHAPE:
1320 # Set the axis to 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001321 shapeList = TosaTensorGen.tgConcatConstInput(rng, shapeList, 0, error_name)
Won Jeon74342e52024-01-09 00:34:40 +00001322 else:
1323 shapeList = TosaTensorGen.tgConcatConstInput(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001324 rng, shapeList, argsDict["axis"], error_name
Won Jeon74342e52024-01-09 00:34:40 +00001325 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001326
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001327 # Override default pCount/cCount for operator
1328 argsDict["p_count"] = count
1329 argsDict["c_count"] = len(shapeList) - count
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001330
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001331 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001332 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001333 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001334
1335 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001336 def tvgLogicalShift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001337 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001338 ):
1339 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001340 pCount, cCount = op["operands"]
1341 assert (
1342 pCount == 2 and cCount == 0
1343 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001344 values_arr = rng.randTensor(shapeList[0], dtypeList[0])
1345 shift_arr = np.int32(rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001346 tens_ser_list = []
1347 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001348 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1349 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001350 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001351 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1352 )
1353
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001354 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001355
    @staticmethod
    def tvgEqual(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Tensor values generator for Op.EQUAL.

        For integer types unsupported by the compliance data generator,
        builds both input tensors directly and forces some matching values
        so the comparison is not trivially all-false; otherwise defers to
        the default lazy generator.
        """
        if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
            # Integer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"

            a_arr = rng.randTensor(shapeList[0], dtypeList[0])
            b_arr = rng.randTensor(shapeList[1], dtypeList[1])

            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        rng.integers(0, max(shapeList[0][axis], shapeList[1][axis]))
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                # Copy one value across so this location compares equal
                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )
1400
1401 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001402 def tvgReduceSum(
1403 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1404 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001405 dtype = dtypeList[0]
1406 if dtype == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001407 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001408 pCount, cCount = op["operands"]
1409 assert (
1410 pCount == 1 and cCount == 0
1411 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
1412 # Limit values so that the sum cannot exceed the range of an int32 during
1413 # summation of any axis
1414 range_val = int((1 << 31) / max(shapeList[0]))
1415 values_arr = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001416 rng.integers(low=-range_val, high=range_val, size=shapeList[0])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001417 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001418 tens_ser_list = []
1419 tens_ser_list.append(
Jeremy Johnson30476252023-11-20 16:15:30 +00001420 testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001421 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001422 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001423 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001424 # ERROR_IF or dot product floating point test
Jeremy Johnson30476252023-11-20 16:15:30 +00001425 if (
1426 error_name is None
1427 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
1428 ):
1429 # Limit ranges for (non error & non compliance) tests by using
1430 # values that can be summed on any axis to not hit infinity
1431 highval_lookup = {
1432 dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
1433 / max(shapeList[0])
1434 }
1435 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001436 rng, dtype, highval_lookup
Jeremy Johnson30476252023-11-20 16:15:30 +00001437 )
1438 assert data_range is not None
1439 argsDict["data_range"] = data_range
1440
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001441 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001442 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001443 )
1444
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001445 @staticmethod
1446 def tvgReduceProduct(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001447 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001448 ):
1449 dtype = dtypeList[0]
1450 if error_name is None:
1451 # Limit ranges for (non error) tests by using
1452 # values that can be multiplied on any axis to not hit infinity
1453 highval_lookup = {
1454 dtype: math.pow(
1455 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
1456 1 / max(shapeList[0]),
1457 )
1458 }
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001459 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001460 assert data_range is not None
1461 argsDict["data_range"] = data_range
1462
1463 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001464 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001465 )
1466
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001467 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001468 def tvgResize(
1469 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1470 ):
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001471 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001472 rng,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001473 dtypeList[0],
1474 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1475 )
1476 if data_range:
1477 argsDict["data_range"] = data_range
1478 # Needed for compliance
1479 argsDict["max_abs_value"] = data_range[1]
1480
Tai Lyc5c2a7e2024-02-22 23:26:28 +00001481 scale_values = argsDict["scale"]
1482 offset_values = argsDict["offset"]
1483 border_values = argsDict["border"]
1484 dtypeList[1] = DType.SHAPE
1485 dtypeList[2] = DType.SHAPE
1486 dtypeList[3] = DType.SHAPE
1487 shapeList[1] = [len(scale_values)]
1488 shapeList[2] = [len(offset_values)]
1489 shapeList[3] = [len(border_values)]
1490 argsDict["fixed_data"] = [None, scale_values, offset_values, border_values]
1491
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001492 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001493 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001494 )
1495
Jeremy Johnson30476252023-11-20 16:15:30 +00001496 # Set the POW exponent high data range
1497 TVG_FLOAT_HIGH_VALUE_POW_EXP = {
1498 DType.FP32: 10.0,
1499 DType.FP16: 10.0,
1500 DType.BF16: 10.0,
1501 }
1502 # POW highest base value (within a safe margin of error) that can be raised
1503 # to +ve exponent that doesn't become Infinity
1504 TVG_FLOAT_HIGH_VALUE_POW_BASE = {
1505 DType.FP32: math.floor(
1506 math.pow(
1507 TVG_FLOAT_HIGH_VALUE[DType.FP32],
1508 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1509 )
1510 ),
1511 DType.FP16: math.floor(
1512 math.pow(
1513 TVG_FLOAT_HIGH_VALUE[DType.FP16],
1514 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1515 )
1516 ),
1517 DType.BF16: math.floor(
1518 math.pow(
1519 TVG_FLOAT_HIGH_VALUE[DType.BF16],
1520 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1521 )
1522 ),
1523 }
1524 # POW lowest base value (within a safe margin of error) that can be raised
1525 # to -ve exponent that doesn't become Infinity
1526 TVG_FLOAT_LOW_VALUE_POW_BASE = {
1527 DType.FP32: math.ceil(
1528 math.pow(
1529 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
1530 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1531 )
1532 * 1000
1533 )
1534 / 1000,
1535 DType.FP16: math.ceil(
1536 math.pow(
1537 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
1538 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1539 )
1540 * 1000
1541 )
1542 / 1000,
1543 DType.BF16: math.ceil(
1544 math.pow(
1545 1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
1546 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1547 )
1548 * 1000
1549 )
1550 / 1000,
1551 }
1552
1553 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001554 def tvgPow(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001555 if error_name is not None:
1556 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001557 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001558 )
1559 dtype = dtypeList[0]
1560 # Different ranges for POW
1561 test_set = argsDict["s"]
1562 if test_set == 0:
1563 # Positive base with fractional exponent
1564 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001565 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001566 dtype,
1567 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1568 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1569 )
1570 exp_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001571 rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
Jeremy Johnson30476252023-11-20 16:15:30 +00001572 )
1573 exp_round = False
1574 else:
1575 # Integer exponent
1576 exp_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001577 rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
Jeremy Johnson30476252023-11-20 16:15:30 +00001578 )
1579 exp_round = True
1580 if test_set == 1:
1581 # Positive base
1582 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001583 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001584 dtype,
1585 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1586 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1587 )
1588 else:
1589 assert test_set == 2
1590 # Negative base
1591 # Supply new look up tables with negative values
1592 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001593 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001594 dtype,
1595 {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
1596 {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
1597 )
1598
1599 data_range_list = (
1600 {
1601 "range": base_range,
1602 },
1603 {
1604 "range": exp_range,
1605 "round": exp_round,
1606 },
1607 )
1608 argsDict["data_range_list"] = data_range_list
1609 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001610 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001611 )
1612
1613 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001614 def tvgLogRsqrt(
1615 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1616 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001617 # LOG & RSQRT data range from lowest expressible positive number to
1618 # largest to avoid NaNs
1619 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001620 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001621 dtypeList[0],
1622 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1623 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1624 )
1625 if data_range:
1626 argsDict["data_range"] = data_range
1627
1628 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001629 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001630 )
1631
    # Set the EXP data range to the log of the largest to smallest values
    # to avoid infinities or making the result zero
    TVG_FLOAT_HIGH_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
    # Corresponding low end - log of the type's low value look up
    TVG_FLOAT_LOW_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
    }
1644
1645 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001646 def tvgExp(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001647 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001648 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001649 dtypeList[0],
1650 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1651 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1652 )
1653 if data_range:
1654 argsDict["data_range"] = data_range
1655
1656 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001657 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001658 )
1659
1660 @staticmethod
1661 def tvgFullyConnected(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001662 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson30476252023-11-20 16:15:30 +00001663 ):
1664 dtype = dtypeList[0]
1665 if (
1666 error_name is None
1667 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
Jeremy Johnson718f3472023-11-30 14:18:19 +00001668 and dtype in (DType.BF16,)
Jeremy Johnson30476252023-11-20 16:15:30 +00001669 ):
Jeremy Johnson718f3472023-11-30 14:18:19 +00001670 # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
Jeremy Johnson30476252023-11-20 16:15:30 +00001671 # Limit ranges for (non error & non compliance) FP tests by using
1672 # values that can be multiplied on any axis to not hit infinity/NaN
1673 IC = shapeList[0][1]
1674 highval_lookup = {
1675 dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
1676 }
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001677 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnson30476252023-11-20 16:15:30 +00001678 assert data_range is not None
1679 argsDict["data_range"] = data_range
1680
1681 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001682 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001683 )
1684
Jeremy Johnson708da822023-11-15 16:25:45 +00001685 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001686 def tvgCast(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson708da822023-11-15 16:25:45 +00001687 in_dtype = dtypeList[0]
1688 out_dtype = argsDict["out_type"]
1689 # Create look up to limit input tensor to output type maximums to avoid
1690 # FP infinities and saturation of integers
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001691 out_range = rng.dTypeRange(out_dtype, high_inclusive=True)
Jeremy Johnson708da822023-11-15 16:25:45 +00001692 highval_lookup = {in_dtype: out_range[1]}
1693 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001694 rng,
Jeremy Johnson708da822023-11-15 16:25:45 +00001695 in_dtype,
1696 highval_lookup,
1697 )
1698
1699 assert data_range is not None
1700 argsDict["data_range"] = data_range
1701
1702 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001703 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson708da822023-11-15 16:25:45 +00001704 )
1705
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001706 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001707 def tvgGather(
1708 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1709 ):
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001710 K = shapeList[0][1]
1711
1712 # Fix the type of the indices tensor
1713 dtypeList[1] = DType.INT32
1714
1715 dtype = dtypeList[0]
1716 if not gtu.dtypeIsSupportedByCompliance(dtype):
1717 # Test unsupported by data generator
1718 op = testGen.TOSA_OP_LIST[opName]
1719 pCount, cCount = op["operands"]
1720 assert (
1721 pCount == 2 and cCount == 0
1722 ), "Op.GATHER must have 2 placeholders, 0 consts"
1723
1724 tens_ser_list = []
1725 for idx, shape in enumerate(shapeList):
1726 dtype = dtypeList[idx]
1727 if idx != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001728 arr = rng.randTensor(shape, dtype)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001729 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1730 else:
1731 # Limit data range of indices tensor upto K (exclusive)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001732 arr = rng.randTensor(shape, dtype, (0, K))
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001733 # To match old functionality - create indices as CONST
1734 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1735
1736 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1737
1738 else:
1739 # ERROR_IF or floating point test
1740 # Use inclusive values upto index K for indices tensor
1741 data_range_list = (
1742 {"range": None},
1743 {"range": (0, K - 1)},
1744 )
1745 argsDict["data_range_list"] = data_range_list
1746
1747 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001748 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001749 )
1750
    @staticmethod
    def tvgScatter(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Tensor values generator for SCATTER.

        Builds an indices tensor whose values stay below dimension K of the
        values_in tensor and, for directly generated data, never repeat an
        output index as required by the specification.
        """
        K = shapeList[0][1]
        W = shapeList[2][1]

        # Work out an indices tensor here with data that doesn't exceed the
        # dimension K of the values_in tensor and does NOT repeat the same K
        # location as needed by the spec:
        # "It is not permitted to repeat the same output index within a single
        # SCATTER operation and so each output index occurs at most once."
        assert K >= W, "Op.SCATTER W must be smaller or equal to K"

        # Fix the type of the indices tensor
        dtypeList[1] = DType.INT32

        dtype = dtypeList[0]
        if not gtu.dtypeIsSupportedByCompliance(dtype):
            # Test unsupported by data generator
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 3 and cCount == 0
            ), "Op.SCATTER must have 3 placeholders, 0 consts"

            tens_ser_list = []
            for idx, shape in enumerate(shapeList):
                dtype = dtypeList[idx]
                if idx != 1:
                    arr = rng.randTensor(shape, dtype)
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    # Create the indices array
                    assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
                    arr = []
                    for n in range(shape[0]):
                        # Get a shuffled list of output indices (0 to K-1) and
                        # limit length to W
                        arr.append(rng.permutation(K)[:W])
                    indices_arr = np.array(arr, dtype=np.int32)  # (N, W)
                    # To match old functionality - create indices as CONST
                    tens_ser_list.append(
                        testGen.ser.addConst(shape, dtype, indices_arr)
                    )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        else:
            # ERROR_IF or floating point test
            # Use inclusive values upto index K for indices tensor
            data_range_list = (
                {"range": None},
                {"range": (0, K - 1)},
                {"range": None},
            )
            argsDict["data_range_list"] = data_range_list

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )
1812
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001813
1814class TosaArgGen:
1815 """Argument generators create exhaustive or random lists of attributes for
1816 operators that take attributes or other parameters.
1817
1818 The return value is a list of (descriptive_name, [arglist]) tuples where
1819 the descriptive_name is appended to the test name and the arglist is expanded
1820 as arguments to the operator build function.
1821 """
1822
    def __init__(self):
        # No instance state - all the argument generators are static methods
        pass
1825
1826 @staticmethod
evacha019c96eef2024-02-07 11:21:55 +00001827 def _add_data_generators(testGen, opName, shapeList, dtype, arg_list, error_name):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001828 """Add extra tests for each type of data generator for this op."""
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001829 if (
1830 error_name is None
1831 and "data_gen" in testGen.TOSA_OP_LIST[opName]
1832 and gtu.dtypeIsSupportedByCompliance(dtype)
1833 ):
Tai Ly60dc48c2024-03-08 22:19:41 +00001834 if gtu.dtypeIsFloat(dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001835 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
1836 else:
1837 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
1838 else:
1839 # Error test or No data generator types listed - assume random
1840 dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
1841
1842 # Expand arg list with other data generator types
1843 new_arg_list = []
1844 for dg_type in dataGenTypesList:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001845 for arg_str, args_dict in arg_list:
evacha019c96eef2024-02-07 11:21:55 +00001846
1847 if dg_type == gtu.DataGenType.FULL_RANGE:
1848 tensor_size = gtu.product(shapeList[0])
1849 if tensor_size >= gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]:
1850 # Large enough tensor data size for full range, add a single test
1851 num_test_sets = 0
1852 else:
1853 # Not enough data size for full range of values, revert to random numbers
1854 dg_type = gtu.DataGenType.PSEUDO_RANDOM
1855
Jeremy Johnson1271c442023-09-05 11:39:26 +01001856 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
Jeremy Johnson30476252023-11-20 16:15:30 +00001857 if error_name is None:
1858 num_test_sets = (
1859 args_dict["num_test_sets"]
1860 if "num_test_sets" in args_dict
1861 else 0
1862 )
1863 else:
evacha019c96eef2024-02-07 11:21:55 +00001864 # Add single test for pseudo random
Jeremy Johnson30476252023-11-20 16:15:30 +00001865 num_test_sets = 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01001866
1867 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
1868 # Extra tests for each dot product test set
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001869 dot_products = args_dict["dot_products"]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001870 if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001871 shape_info = (
1872 " ({})".format(testGen.shapeStr(args_dict["shape"]))
1873 if "shape" in args_dict
1874 else ""
1875 )
Jeremy Johnsonaf090182024-02-13 18:25:39 +00001876 logger.info(
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00001877 f"Skipping {opName}{shape_info} {gtu.DTYPE_ATTRIBUTES[dtype]['json']} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
Jeremy Johnson1271c442023-09-05 11:39:26 +01001878 )
1879 continue
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001880 # KS and acc_type is required by all dot product generators
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001881 assert "ks" in args_dict
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001882 assert "acc_type" in args_dict
Jeremy Johnson1271c442023-09-05 11:39:26 +01001883
Jeremy Johnson30476252023-11-20 16:15:30 +00001884 num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS
1885
1886 if num_test_sets > 0:
1887 for s in range(0, num_test_sets):
evacha019c96eef2024-02-07 11:21:55 +00001888 set_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
1889 set_args_dict = args_dict.copy()
1890 set_args_dict["s"] = s
1891 set_args_dict["dg_type"] = dg_type
1892 new_arg_list.append((set_arg_str, set_args_dict))
Jeremy Johnson30476252023-11-20 16:15:30 +00001893 else:
1894 # Default is a single test
evacha019c96eef2024-02-07 11:21:55 +00001895 new_args_dict = args_dict.copy()
1896 new_args_dict["dg_type"] = dg_type
1897 new_arg_list.append((arg_str, new_args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001898
1899 return new_arg_list
1900
1901 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001902 def agNone(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001903 """A trivial argument generator for operators that don't take any
1904 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001905 arg_list = TosaArgGen._add_data_generators(
1906 testGen,
1907 opName,
evacha019c96eef2024-02-07 11:21:55 +00001908 shapeList,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001909 dtype,
1910 [("", {})],
1911 error_name,
1912 )
1913 # Return list of tuples: (arg_str, args_dict)
1914 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001915
1916 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001917 def agPow(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001918 """Pow operator needs different test sets to cover random numbers
1919 without creating NaNs or Infs"""
1920 arg_list = TosaArgGen._add_data_generators(
1921 testGen,
1922 opName,
evacha019c96eef2024-02-07 11:21:55 +00001923 shapeList,
Jeremy Johnson30476252023-11-20 16:15:30 +00001924 dtype,
1925 [("", {"num_test_sets": 3})],
1926 error_name,
1927 )
1928 # Return list of tuples: (arg_str, args_dict)
1929 return arg_list
1930
1931 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001932 def agAxis(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001933 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001934 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001935 shape = shapeList[0]
1936
1937 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001938 # Set too small axis
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001939 axes = [rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001940 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001941 # Set too large axis
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001942 axes = [rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001943 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001944 # Create tests for each dimension
1945 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001946
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001947 opid = testGen.TOSA_OP_LIST[opName]["op"]
1948
1949 for a in axes:
1950 args_dict = {"axis": int(a)}
1951 if opid == Op.REDUCE_SUM:
Jeremy Johnsone52c0a32024-03-11 09:58:24 +00001952 output_shape = shape.copy()
1953 if error_name is None:
1954 # It only matters that we calculate the dot_products correctly
1955 # for non error_if tests as they should never be run
1956 output_shape[a] = 1
1957 args_dict["dot_products"] = gtu.product(output_shape)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001958 args_dict["shape"] = shape
1959 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1960 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1961
1962 arg_list.append(("axis{}".format(a), args_dict))
1963
1964 arg_list = TosaArgGen._add_data_generators(
1965 testGen,
1966 opName,
evacha019c96eef2024-02-07 11:21:55 +00001967 shapeList,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001968 dtype,
1969 arg_list,
1970 error_name,
1971 )
1972 # Return list of tuples: (arg_str, args_dict)
1973 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001974
1975 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001976 def _calculate_sparsity(num_tests, sparsity_factor):
1977 sparsity = num_tests // sparsity_factor + 1
1978 # If there are only a small number of tests, just select them all
1979 if sparsity < 13:
1980 sparsity = 1
1981 # To get a variety of parameter combinations sparsity should not be a
1982 # multiple of 2, 3 or 5
1983 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1984 sparsity += 1
1985 return sparsity
1986
Jeremy Johnsondd975b82024-02-28 17:29:13 +00001987 # Maximum number of error_if variants to produce
Jeremy Johnson87460262024-03-25 09:46:02 +00001988 MAX_TESTS_ERROR_IFS = 3
Jeremy Johnsondd975b82024-02-28 17:29:13 +00001989
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001990 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001991 def agConv(testGen, rng, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson0c716862023-04-13 17:18:19 +01001992 # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001993 arg_list = []
1994
Jeremy Johnson0c716862023-04-13 17:18:19 +01001995 # Shape: Batches, (Depth), Height, Width, Channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001996 ifm_shape = shapeList[0]
Jeremy Johnson0c716862023-04-13 17:18:19 +01001997 # Shape: (OFM channels), (KD), KH, KW, IFM channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001998 filter_shape = shapeList[1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001999
Tai Lyf36f2562024-03-14 16:21:29 +00002000 accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
2001
2002 if error_name == ErrorIf.WrongAccumulatorType:
2003 accum_dtypes = (
2004 [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
2005 )
James Ward8b390432022-08-12 20:48:56 +01002006
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002007 # For op type checks
2008 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002009
2010 # Check the rank
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002011 rank = 5 if op["op"] == Op.CONV3D else 4
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002012 if error_name != ErrorIf.WrongRank:
2013 assert len(ifm_shape) == rank
2014 assert len(filter_shape) == rank
2015
Jeremy Johnson0c716862023-04-13 17:18:19 +01002016 # kernel rank omits channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002017 k_rank = rank - 2
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002018 k_pos = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
Jeremy Johnson0c716862023-04-13 17:18:19 +01002019 k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002020 # compliance size - KS
2021 k_size = gtu.product(k_shape)
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002022 if not op["op"] == Op.DEPTHWISE_CONV2D:
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002023 k_size *= ifm_shape[-1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002024
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002025 def get_conv_output_info(p, s, d, fix_up_padding=False):
2026 # Work out remainders and output dimensions with an
2027 # option to adjust paddings to create a valid operation
2028 nonlocal ifm_shape, k_shape, error_name, k_rank
2029 if fix_up_padding:
2030 p = list(p) # Make paddings editable
2031 outputs_no_stride = []
2032 remainders = []
2033 outputs = []
2034 for index in range(k_rank):
2035 pad_offset = index * 2
2036 fixed = False
2037 # Fix up pad values to produce valid conv2d
2038 while not fixed:
2039 # Output dimension without being adjusted for stride
2040 output_no_stride = (
2041 ifm_shape[index + 1]
2042 - 1
2043 + p[pad_offset]
2044 + p[pad_offset + 1]
2045 - (k_shape[index] - 1) * d[index]
2046 )
2047 # Tensor left over after applying striding
2048 remainder = output_no_stride % s[index]
2049 if not fix_up_padding:
2050 # Just want remainders and outputs
2051 break
2052 if output_no_stride <= 0:
2053 p[pad_offset + 1] += abs(output_no_stride) + 1
2054 continue
2055 if error_name == ErrorIf.ConvOutputShapeNonInteger:
2056 if remainder:
2057 # Conditions to trigger the test
2058 fixed = True
2059 else:
2060 p[pad_offset + 1] += 1
2061 else:
2062 if remainder:
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002063 # Stride will be negative for StrideSmallerOne
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002064 assert remainder > 0 or (
2065 error_name == ErrorIf.StrideSmallerOne and remainder < 0
2066 )
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002067 p[pad_offset + 1] += abs(remainder)
2068 else:
2069 fixed = True
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002070 outputs_no_stride.append(output_no_stride)
2071 remainders.append(remainder)
2072 # Output dimension taking in to account stride
2073 outputs.append((output_no_stride // s[index]) + 1)
2074
2075 if fix_up_padding:
2076 p = tuple(p) # Make the paddings read-only
2077 assert min(outputs_no_stride) > 0, "Fix up did not work!"
2078 return p, remainders, outputs, outputs_no_stride
2079
2080 # Only fix up padding for conv2d and float types currently
2081 fix_up_padding = gtu.dtypeIsFloat(dtypes[0]) and op["op"] == Op.CONV2D
2082 # Allow any size of output dimension
2083 max_dim_size = None
2084 # Include all tests by default
2085 sparsity = 1
2086
2087 # Work out padding, strides and dilation ranges depending on
2088 # error and arguments
2089 if error_name in (
2090 ErrorIf.PadSmallerZero,
2091 ErrorIf.StrideSmallerOne,
2092 ErrorIf.DilationSmallerOne,
2093 ):
2094 # Use specific invalid value(s)
2095 if error_name == ErrorIf.PadSmallerZero:
2096 # Create negative paddings but with positive opposite paddings
2097 neg_pad = rng.choice(range(-5, 0))
2098 p_vals = [neg_pad, abs(neg_pad)]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002099 else:
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002100 p_vals = [0, 0]
2101 if error_name == ErrorIf.StrideSmallerOne:
2102 # Can't use stride=0, as it is used to derive output shape, as a divisor
2103 s_vals = [rng.choice(range(-5, 0))]
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002104 else:
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002105 s_vals = [1]
2106 if error_name == ErrorIf.DilationSmallerOne:
2107 d_vals = [rng.choice(range(-5, 1))]
2108 else:
2109 d_vals = [1]
2110 paddings = {tuple(p_vals) * k_rank}
2111 strides = {tuple(s_vals) * k_rank}
2112 dilations = {tuple(d_vals) * k_rank}
2113
2114 fix_up_padding = True # Need to fix up paddings to be valid
2115
2116 elif testGen.args.level8k and error_name is None:
Jeremy Johnson0c716862023-04-13 17:18:19 +01002117 # Only test 8k levels boundaries
2118 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2119 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2120 bigPadding = bigKernel
2121
2122 dilation_shape = [1] * k_rank
2123 pad_shape = [0] * k_rank * 2
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002124 if op["op"] == Op.CONV3D:
Jeremy Johnson0c716862023-04-13 17:18:19 +01002125 # Small stride apart from for big kernel (see below) to keep
2126 # tensor size/calculation small
2127 stride_shape = [1] * k_rank
2128 for idx in range(k_rank):
2129 pad_offset = idx * 2
2130 if k_shape[idx] == bigKernel:
2131 # Padding shape needs to account for tensor shape
2132 pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
2133 pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
2134 # Big stride to reduce output size
2135 stride_shape[idx] = bigKernel
2136 else:
2137 # Account for kernel size
2138 pad_shape[pad_offset] = k_shape[idx] - 1
2139 else:
2140 # Always have a large stride with extra padding and dilation to keep
2141 # tensor calculation reasonable
2142 stride_shape = [bigKernel] * k_rank
2143 for idx in range(k_rank):
2144 # Dilation shape must account for kernel size
2145 dilation_shape[idx] = bigKernel // k_shape[idx]
2146 # Padding shape needs to accommodate tensor/kernel & dilation
2147 pad_offset = idx * 2
2148 pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
2149 pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
2150
2151 strides = {tuple(stride_shape)}
2152 dilations = {tuple(dilation_shape)}
2153 paddings = {tuple(pad_shape)}
2154 # Create a limit for the output dimensions size
2155 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2156
2157 # Currently allow all combinations that are reasonable size
2158 sparsity = 1
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002159 else:
2160 # Generate comprehensive argument lists
2161 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
2162 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
2163 # Stride must be greater than 1 to force non-integer error
2164 startStride = 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
2165 s_vals = [x for x in range(startStride, testGen.args.max_conv_stride + 1)]
2166 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002167
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002168 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
2169 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
2170
2171 if error_name is None and testGen.args.oversize:
2172 # add some oversize argument values
2173 if max(ifm_shape) < 64:
2174 bigPadding = 9
2175 paddings.update(
2176 {
2177 x
2178 for x in itertools.product(
2179 *([[0, bigPadding]] * (k_rank * 2))
2180 )
2181 }
2182 )
2183 bigStride = 8
2184 strides.update(
2185 {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
2186 )
2187 bigDilation = 7
2188 dilations.update(
2189 {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
2190 )
2191
2192 if error_name is None:
2193 # There are too many parameter combinations, so generate them sparsely,
2194 sparsity_factor = 120
2195 sparsity = TosaArgGen._calculate_sparsity(
2196 len(paddings) * len(strides) * len(dilations), sparsity_factor
2197 )
2198
2199 # Run through all the argument options creating valid test cases
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002200 more_tests = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002201 n = 0
Tai Lyf36f2562024-03-14 16:21:29 +00002202 for a in accum_dtypes:
2203 for s in sorted(list(strides)):
2204 for p in sorted(list(paddings)):
2205 for d in sorted(list(dilations)):
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002206 if more_tests and (n % sparsity == 0):
2207 (
2208 p,
2209 remainders,
2210 outputs,
2211 outputs_no_stride,
2212 ) = get_conv_output_info(p, s, d, fix_up_padding)
2213 # Following is like checking each dimension N:
2214 # (ifm_shape[N+1] - 1 + p[N*2] + p[N*2+1]) > d[N] * (k_shape[N] - 1)
2215 if min(outputs_no_stride) <= 0:
2216 # Not a valid operation
2217 n += 1 # Increment count of tests
2218 continue
Tai Lyf36f2562024-03-14 16:21:29 +00002219
2220 if (
2221 # the parameters must produce integer exact output
2222 error_name != ErrorIf.ConvOutputShapeNonInteger
2223 and max(remainders) == 0
2224 ) or (
2225 error_name == ErrorIf.ConvOutputShapeNonInteger
2226 and max(remainders) > 0
2227 ):
2228 if (
2229 max_dim_size is not None
2230 and max(outputs) >= max_dim_size
2231 ):
2232 # Test will consume too much memory - skip it
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002233 logger.debug(
2234 "agConv: Convolution output too big - skipped"
2235 )
Tai Lyf36f2562024-03-14 16:21:29 +00002236 continue
2237
2238 # Compliance - number of dot product calculations
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00002239 if op["op"] == Op.DEPTHWISE_CONV2D:
Tai Lyf36f2562024-03-14 16:21:29 +00002240 # N*OH*OW*C*M
2241 dots = gtu.product(
2242 (ifm_shape[0], *outputs, *filter_shape[2:])
2243 )
2244 else:
2245 # N*OH*OW*OC or N*OD*OH*OW*OC
2246 dots = gtu.product(
2247 (ifm_shape[0], *outputs, filter_shape[0])
2248 )
2249 args_dict = {
2250 "acc_type": a,
2251 "stride": s,
2252 "pad": p,
2253 "dilation": d,
2254 "kernel": k_shape,
2255 "ks": k_size,
2256 "dot_products": dots,
2257 "shape": ifm_shape,
2258 }
2259
2260 # Support for larger values than 9 needs different delimiter
2261 delim = "" if max(s + p + d) <= 9 else "x"
2262 arg_list.append(
2263 (
2264 "acc{}_st{}_pad{}_dilat{}".format(
2265 testGen.typeStr(a),
2266 delim.join([str(x) for x in s]),
2267 delim.join([str(x) for x in p]),
2268 delim.join([str(x) for x in d]),
2269 ),
2270 args_dict,
2271 )
2272 )
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002273 if (
2274 error_name
Jeremy Johnson87460262024-03-25 09:46:02 +00002275 and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002276 ):
2277 # Found enough errors
2278 logger.debug(
2279 f"Skipping creating more conv error tests for {error_name}"
2280 )
2281 more_tests = False
Tai Lyf36f2562024-03-14 16:21:29 +00002282 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002283
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002284 arg_list = TosaArgGen._add_data_generators(
2285 testGen,
2286 opName,
evacha019c96eef2024-02-07 11:21:55 +00002287 shapeList,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002288 dtypes[0],
2289 arg_list,
2290 error_name,
2291 )
2292 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002293 return arg_list
2294
2295 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002296 def agFullyConnected(testGen, rng, opName, shapeList, dtypes, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002297
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002298 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002299 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01002300
2301 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002302 accum_dtype = gtu.get_wrong_output_type(opName, rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01002303 elif error_name == ErrorIf.WrongInputType:
2304 # Pick some potentially correct output dtype if input type is incorrect
2305 accum_dtype = DType.INT32
2306 else:
Tai Lyf36f2562024-03-14 16:21:29 +00002307 accum_dtype = dtypes[-1] # use output dtype as accum_dtype
James Ward8b390432022-08-12 20:48:56 +01002308
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002309 # Set up compliance info
2310 args_dict = {
2311 "acc_type": accum_dtype,
2312 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
2313 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
2314 "shape": shapeList[0],
2315 }
2316
2317 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
2318
2319 arg_list = TosaArgGen._add_data_generators(
2320 testGen,
2321 opName,
evacha019c96eef2024-02-07 11:21:55 +00002322 shapeList,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002323 input_dtype,
2324 arg_list,
2325 error_name,
2326 )
2327 # Return list of tuples: (arg_str, args_dict)
2328 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002329
2330 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002331 def agMatMul(testGen, rng, opName, shapeList, dtype, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002332 # Get valid accumulate type(s)
2333 if dtype == DType.INT8:
2334 accum_dtypes = [DType.INT32]
2335 elif dtype == DType.INT16:
2336 accum_dtypes = [DType.INT48]
2337 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002338 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002339 elif dtype == DType.BF16:
2340 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002341 elif dtype == DType.FP32:
2342 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002343 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2344 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002345 elif error_name is None:
2346 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
2347
2348 if error_name == ErrorIf.WrongOutputType:
2349 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002350 accum_dtypes = [gtu.get_wrong_output_type(opName, rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01002351 elif error_name == ErrorIf.WrongInputType:
2352 # Pick some potentially correct output dtype if input type is incorrect
2353 accum_dtypes = [DType.INT32]
2354
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002355 # Set up compliance info
2356 args_dict = {
2357 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
2358 # Set dot_products = N*H*W
2359 "dot_products": gtu.product(
2360 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
2361 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002362 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002363 }
2364
2365 # Create arg tuple of string and dict
2366 arg_list = []
2367 for a in accum_dtypes:
2368 d = args_dict.copy()
2369 d["acc_type"] = a
2370 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01002371
2372 arg_list = TosaArgGen._add_data_generators(
2373 testGen,
2374 opName,
evacha019c96eef2024-02-07 11:21:55 +00002375 shapeList,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002376 dtype,
2377 arg_list,
2378 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002379 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002380 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01002381 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002382
2383 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002384 def agTransposeConv2D(testGen, rng, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002385 arg_list = []
2386
Jeremy Johnson0c716862023-04-13 17:18:19 +01002387 if testGen.args.level8k and error_name is not None:
2388 # Don't produce negative large tests
2389 return arg_list
2390
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002391 ifm_shape = shapeList[0]
2392 filter_shape = shapeList[1]
2393
Tai Lyf36f2562024-03-14 16:21:29 +00002394 accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
2395
2396 if error_name == ErrorIf.WrongAccumulatorType:
2397 accum_dtypes = (
2398 [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
2399 )
James Ward8b390432022-08-12 20:48:56 +01002400
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002401 # Must be rank 4
2402 if error_name != ErrorIf.WrongRank:
2403 assert len(ifm_shape) == 4
2404 assert len(filter_shape) == 4
2405
Jeremy Johnson0c716862023-04-13 17:18:19 +01002406 k_shape = tuple(filter_shape[1:3])
Jeremy Johnson95a67102024-01-10 14:16:39 +00002407 # compliance size - KS
2408 k_size = gtu.product((*k_shape, ifm_shape[3]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002409
Jeremy Johnson0c716862023-04-13 17:18:19 +01002410 if not testGen.args.level8k:
2411 # Generate comprehensive argument lists
2412 # - except for named errors, which use specific invalid value(s)
2413 smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
2414 if error_name == ErrorIf.PadLargerEqualKernel:
2415 max_filter_size = -max(k_shape[0], k_shape[1])
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002416 p_vals = [rng.choice(range(max_filter_size - 10, max_filter_size))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002417 else:
2418 p_vals = [
2419 x
2420 for x in range(
2421 smallest_padding_size, testGen.args.max_conv_padding + 1
2422 )
2423 ]
2424 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2425 if error_name == ErrorIf.StrideSmallerOne:
2426 # Can't use stride=0, as it is used to derive output shape, as a divisor
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002427 s_vals = [rng.choice(range(-5, 0))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002428 else:
2429 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
2430 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002431
Jeremy Johnson0c716862023-04-13 17:18:19 +01002432 if not error_name and testGen.args.oversize:
2433 # add some oversize argument values
2434 if max(ifm_shape) < 64:
2435 bigPadding = 9
2436 paddings.update(
2437 {
2438 x
2439 for x in itertools.product(
2440 *([[smallest_padding_size, bigPadding]] * 4)
2441 )
2442 }
2443 )
2444 bigStride = 8
2445 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
2446
2447 # There are too many parameter combinations, so generate them sparsely,
2448 # very sparse for negative tests
2449 sparsity_factor = 2 if error_name else 10
2450 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
2451 # If there are only a small number of tests, just select them all
2452 if sparsity < 13:
2453 sparsity = 1
2454 # To get a variety of parameter combinations sparsity should not be a
2455 # multiple of 2, 3 or 5
2456 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
2457 sparsity += 1
2458 else:
2459 # Only test 8k levels boundaries
2460 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2461 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2462 bigPadding = bigKernel
2463
2464 pad_shape = [0] * (len(k_shape) * 2)
2465 stride_shape = [1] * len(k_shape)
2466 # The point at which input dimension combined with the stride will
2467 # create large output sizes!
2468 LARGE_SIZE = 2
2469 for idx in range(len(k_shape)):
2470 pad_offset = idx * 2
2471 if k_shape[idx] == bigKernel:
2472 # Set large stride
2473 stride_shape[idx] = bigKernel
2474 # Use negative output padding to reduce shape size
2475 pad_shape[pad_offset] = -(bigPadding - 1)
2476 if ifm_shape[idx + 1] > LARGE_SIZE:
2477 pad_shape[pad_offset + 1] = -(bigPadding - 1)
2478 else:
2479 # The other dimension should be the bigKernel
2480 alt_idx = 1 - idx
2481 if (
2482 k_shape[alt_idx] == bigKernel
2483 and ifm_shape[alt_idx + 1] < LARGE_SIZE
2484 ):
2485 # As the input is small, the large stride won't
2486 # affect the output so we can add some padding
2487 pad_shape[pad_offset + 1] = bigPadding
2488
2489 strides = {tuple(stride_shape)}
2490 paddings = {tuple(pad_shape)}
2491
2492 # Currently allow all combinations that are reasonable size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002493 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002494
2495 n = 0
Tai Lyf36f2562024-03-14 16:21:29 +00002496 for a in accum_dtypes:
2497 for s in sorted(list(strides)):
2498 for p in sorted(list(paddings)):
2499 if n % sparsity == 0:
2500 # Determine the output shape
2501 oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
2502 ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
2503 os = [ifm_shape[0], oh, ow, filter_shape[0]]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002504
Tai Lyf36f2562024-03-14 16:21:29 +00002505 # N*OH*OW*OC
2506 dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
2507 args_dict = {
2508 "acc_type": a,
2509 "stride": s,
2510 "pad": p,
2511 "kernel": k_shape,
2512 "ks": k_size,
2513 "dot_products": dots,
2514 "shape": ifm_shape,
2515 "out_shape": os,
2516 }
Jeremy Johnson95a67102024-01-10 14:16:39 +00002517
Tai Lyf36f2562024-03-14 16:21:29 +00002518 # Support for larger values than 9 needs different delimiter
2519 delim = "" if max(s + p) <= 9 else "x"
2520 arg_list.append(
2521 (
2522 "acc{}_st{}_pad{}_os{}".format(
2523 testGen.typeStr(a),
2524 delim.join([str(x) for x in s]),
2525 delim.join([str(x) for x in p]),
2526 "x".join([str(x) for x in os]),
2527 ),
2528 args_dict,
2529 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002530 )
Tai Lyf36f2562024-03-14 16:21:29 +00002531 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002532
Jeremy Johnson95a67102024-01-10 14:16:39 +00002533 arg_list = TosaArgGen._add_data_generators(
2534 testGen,
2535 opName,
evacha019c96eef2024-02-07 11:21:55 +00002536 shapeList,
Jeremy Johnson95a67102024-01-10 14:16:39 +00002537 dtypes[0],
2538 arg_list,
2539 error_name,
2540 )
2541 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002542 return arg_list
2543
2544 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002545 def agPad(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002546 rank = len(shapeList[0])
2547
2548 # Exhaustively test combinations of padding on each side of each dimension
2549 # - the range of padding values is defined by pad_min and pad_max
2550 # - for padding >9, the name format needs to be more distinctive
2551 pad_min, pad_max = 0, 1
2552 pad_values = [x for x in range(pad_min, pad_max + 1)]
2553 if error_name == ErrorIf.PadSmallerZero:
2554 pad_values = [x for x in range(-2, 0)]
2555 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
2556 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
2557
2558 if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002559 pad_const_int = rng.randNumberDType(dtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002560 pad_const_fp = 0
Tai Ly60dc48c2024-03-08 22:19:41 +00002561 elif gtu.dtypeIsFloat(dtype):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002562 pad_const_int = 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002563 pad_const_fp = rng.randNumberDType(dtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002564 else:
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002565 assert error_name == ErrorIf.WrongInputType
2566 pad_const_int = 0
2567 pad_const_fp = 0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002568
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002569 list_shape_pad_values = list(shape_pad_values)
2570 # If we are producing tests for rank 6 or greater use sparsity
2571 if len(list_shape_pad_values) > 1024:
2572 sparsity_factor = 2 if error_name else 120
2573 sparsity = TosaArgGen._calculate_sparsity(
2574 len(list_shape_pad_values), sparsity_factor
2575 )
2576 else:
2577 sparsity = 1
2578
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002579 # Build arg list
2580 arg_list = []
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002581 for n, paddings in enumerate(list_shape_pad_values):
James Ward8b390432022-08-12 20:48:56 +01002582 paddings = list(paddings)
2583 args_valid = True
2584
2585 if error_name == ErrorIf.PadSmallerZero:
2586 # Prevent negative output shapes while ensuring still testing for negative padding
2587 for i in range(rank):
2588 dim_after_padding = (
2589 paddings[i][0] + paddings[i][1] + shapeList[0][i]
2590 )
2591 if dim_after_padding < 1:
2592 paddings[i] = (0, 0)
2593 if all([p > -1 for p in paddings[i]]):
2594 args_valid = False
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002595 if args_valid and n % sparsity == 0:
James Ward8b390432022-08-12 20:48:56 +01002596 name = "pad"
2597 for r in range(rank):
2598 before, after = paddings[r]
2599 name = f"{name}{before}{after}"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002600 args_dict = {
2601 "pad": np.array(paddings),
2602 "pad_const_int": pad_const_int,
2603 "pad_const_fp": pad_const_fp,
2604 }
2605 arg_list.append((name, args_dict))
James Ward8b390432022-08-12 20:48:56 +01002606
2607 if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002608 logger.debug(
2609 f"agPad: No PadSmallerZero ErrorIf test created for input shape: {shapeList[0]}"
2610 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002611
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002612 arg_list = TosaArgGen._add_data_generators(
2613 testGen,
2614 opName,
evacha019c96eef2024-02-07 11:21:55 +00002615 shapeList,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002616 dtype,
2617 arg_list,
2618 error_name,
2619 )
2620
2621 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002622 return arg_list
2623
2624 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002625 def agPooling(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002626 arg_list = []
2627
2628 shape = shapeList[0]
2629 if error_name != ErrorIf.WrongRank:
2630 assert len(shape) == 4
2631
Jeremy Johnson0c716862023-04-13 17:18:19 +01002632 test_level8k = testGen.args.level8k and error_name is None
2633
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002634 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002635 startKernel = 2
2636 startPad = 0
2637 if not test_level8k:
2638 # Generate comprehensive argument lists
2639 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2640 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2641 # Stride must be greater than 1 to force non-integer error
2642 s_vals = [
2643 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2644 ]
2645 strides = {x for x in itertools.product(*([s_vals] * 2))}
2646 k_vals = [
2647 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2648 ]
2649 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2650 max_dim_size = None
2651 else:
2652 # Only test 8k levels
2653 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2654 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2655 strides = {(1, bigStride), (bigStride, 4)}
2656 kernels = {(1, bigKernel), (bigKernel, 3)}
2657 paddings = set()
2658 for s in sorted(list(strides)):
2659 for k in sorted(list(kernels)):
2660 padding = []
2661 for idx in range(len(k)):
2662 total_padding = s[idx] - shape[idx + 1] + k[idx]
2663 while total_padding < 0:
2664 # Must meet: shape + padding > kernel
2665 total_padding += s[idx]
2666 if total_padding < k[idx]:
2667 padding.extend([0, total_padding])
2668 else:
2669 # Note this may produce padding >= k[idx] which is not
2670 # allowed - but will be ignored in the creation loop below
2671 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2672 paddings.add(tuple(padding))
2673 # Create a limit for the output dimensions size
2674 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002675
James Ward8b390432022-08-12 20:48:56 +01002676 if opName == "max_pool2d":
2677 accum_dtypes = [None] # max_pool has no accumulate dtype
2678 elif dtype == DType.INT8 or dtype == DType.INT16:
2679 accum_dtypes = [DType.INT32]
2680 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002681 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002682 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002683 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002684 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2685 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002686 elif error_name is None:
2687 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2688 else:
2689 # Set to something for the ErrorIf case which has
2690 # incorrect input data-type
2691 accum_dtypes = [DType.INT32]
2692
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00002693 if error_name == ErrorIf.WrongAccumulatorType:
2694 accum_dtypes = list(gtu.usableDTypes(excludes=accum_dtypes))
2695
Jeremy Johnson0c716862023-04-13 17:18:19 +01002696 if not test_level8k:
2697 if testGen.args.oversize:
2698 # add some oversize argument values
2699 bigStride = 7
2700 bigKernel = 9
2701 strides.update(
2702 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002703 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002704 kernels.update(
2705 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2706 )
2707 if max(shape) < 64:
2708 # padding must be less than the kernel size
2709 bigPadding = bigKernel - 1
2710 paddings.update(
2711 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2712 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002713
Jeremy Johnson87460262024-03-25 09:46:02 +00002714 if error_name:
2715 # Cycle through all error_if tests but we only keep the first few
2716 sparsity = 1
2717 else:
2718 # There are too many parameter combinations, so generate them sparsely
2719 sparsity_factor = 500
2720 sparsity = (
2721 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2722 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002723 else:
2724 # We have already limited test output combinations for 8k tests
2725 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002726
James Ward8b390432022-08-12 20:48:56 +01002727 arg_str = (
2728 "acc{}_st{}_kern{}_pad{}"
2729 if accum_dtypes[0] is not None
2730 else "st{}_kern{}_pad{}"
2731 )
2732
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002733 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01002734 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002735 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01002736
2737 # Support for larger values than 9 needs different delimiter
2738 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01002739 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01002740 delim.join([str(x) for x in stride]),
2741 delim.join([str(x) for x in kern]),
2742 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01002743 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002744 args_dict = {
2745 "stride": stride,
2746 "pad": pad,
2747 "kernel": kern,
2748 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002749 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002750 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
2751 }
James Ward8b390432022-08-12 20:48:56 +01002752
2753 if accum is not None:
2754 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002755 args_dict["acc_type"] = accum
2756 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002757
Jeremy Johnson87460262024-03-25 09:46:02 +00002758 more_tests = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002759 n = 0
James Ward8b390432022-08-12 20:48:56 +01002760 for a in accum_dtypes:
2761 for s in sorted(list(strides)):
2762 for p in sorted(list(paddings)):
2763 for k in sorted(list(kernels)):
2764 if error_name in [
2765 ErrorIf.StrideSmallerOne,
2766 ErrorIf.KernelSmallerOne,
2767 ErrorIf.PadSmallerZero,
2768 ErrorIf.PadLargerEqualKernel,
2769 ]:
2770 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002771 rng, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002772 )
James Ward8b390432022-08-12 20:48:56 +01002773 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002774 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002775 get_arg_list_element(a, sNew, pNew, kNew, shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002776 )
James Ward8b390432022-08-12 20:48:56 +01002777 elif (
Jeremy Johnson87460262024-03-25 09:46:02 +00002778 more_tests
2779 and n % sparsity == 0
James Ward8b390432022-08-12 20:48:56 +01002780 # padding must not exceed the kernel size
2781 and p[0] < k[0]
2782 and p[1] < k[0]
2783 and p[2] < k[1]
2784 and p[3] < k[1]
2785 # the padded shape must exceed the kernel size
2786 and (shape[1] + p[0] + p[1]) > k[0]
2787 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002788 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002789 partial_h = shape[1] + p[0] + p[1] - k[0]
2790 partial_w = shape[2] + p[2] + p[3] - k[1]
2791 remainder_h = partial_h % s[0]
2792 remainder_w = partial_w % s[1]
2793 output_h = partial_h // s[0] + 1
2794 output_w = partial_w // s[1] + 1
Jeremy Johnsonaf090182024-02-13 18:25:39 +00002795 logger.debug(
2796 f"agPooling: {shape} remainder=({remainder_h}, {remainder_w}) output=({output_h}, {output_w})"
2797 )
James Ward8b390432022-08-12 20:48:56 +01002798 if (
2799 # the parameters must produce integer exact output
2800 error_name != ErrorIf.PoolingOutputShapeNonInteger
2801 and remainder_h == 0
2802 and remainder_w == 0
2803 ) or (
2804 error_name == ErrorIf.PoolingOutputShapeNonInteger
2805 and (remainder_h != 0 or remainder_w != 0)
2806 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002807 if (
2808 max_dim_size is not None
2809 and max(output_h, output_w) > max_dim_size
2810 ):
2811 # Test will consume too much memory - skip it
2812 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002813 # Dot products = N*OH*OW*C
2814 dp = gtu.product(
2815 (shape[0], output_h, output_w, shape[3])
2816 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002817 arg_list.append(
2818 get_arg_list_element(a, s, p, k, dp, shape)
2819 )
Jeremy Johnson87460262024-03-25 09:46:02 +00002820 if (
2821 error_name
2822 and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS
2823 ):
2824 # Found enough errors
2825 logger.debug(
2826 f"Skipping creating more pooling error tests for {error_name}"
2827 )
2828 more_tests = False
2829
James Ward8b390432022-08-12 20:48:56 +01002830 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002831
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002832 # Now add data generator types
2833 arg_list = TosaArgGen._add_data_generators(
2834 testGen,
2835 opName,
evacha019c96eef2024-02-07 11:21:55 +00002836 shapeList,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002837 dtype,
2838 arg_list,
2839 error_name,
2840 )
2841
2842 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002843 return arg_list
2844
2845 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002846 def agCast(testGen, rng, opName, shapeList, inDtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002847 arg_list = []
2848
2849 # Enumerate the output types here
2850 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002851 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(inDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002852 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00002853 dtypeList = [
2854 DType.BOOL,
2855 DType.INT16,
2856 DType.INT32,
2857 DType.FP16,
2858 DType.BF16,
2859 DType.FP32,
2860 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002861 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00002862 dtypeList = [
2863 DType.BOOL,
2864 DType.INT8,
2865 DType.INT32,
2866 DType.FP16,
2867 DType.BF16,
2868 DType.FP32,
2869 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002870 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00002871 dtypeList = [
2872 DType.BOOL,
2873 DType.INT8,
2874 DType.INT16,
2875 DType.FP16,
2876 DType.BF16,
2877 DType.FP32,
2878 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002879 elif inDtype == DType.BOOL:
2880 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01002881 elif inDtype == DType.FP16:
Won Jeon2c34b462024-02-06 18:37:00 +00002882 dtypeList = [
2883 DType.INT8,
2884 DType.INT16,
2885 DType.INT32,
2886 DType.FP32,
2887 DType.FP8E4M3,
2888 DType.FP8E5M2,
2889 ]
James Ward24dbc422022-10-19 12:20:31 +01002890 elif inDtype == DType.BF16:
Won Jeon2c34b462024-02-06 18:37:00 +00002891 dtypeList = [
2892 DType.INT8,
2893 DType.INT16,
2894 DType.INT32,
2895 DType.FP32,
2896 DType.FP8E4M3,
2897 DType.FP8E5M2,
2898 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002899 elif inDtype == DType.FP32:
Won Jeon2c34b462024-02-06 18:37:00 +00002900 dtypeList = [
2901 DType.INT8,
2902 DType.INT16,
2903 DType.INT32,
2904 DType.FP16,
2905 DType.BF16,
2906 DType.FP8E4M3,
2907 DType.FP8E5M2,
2908 ]
2909 elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]:
2910 dtypeList = [DType.FP16, DType.BF16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002911 elif error_name == ErrorIf.WrongInputType:
2912 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002913 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002914 else:
2915 raise Exception("Unexpected input dtype: {}".format(inDtype))
2916
2917 for dtype in dtypeList:
Jeremy Johnson708da822023-11-15 16:25:45 +00002918 arg_list.append(
2919 ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
2920 )
2921
2922 # Now add data generator types
2923 arg_list = TosaArgGen._add_data_generators(
2924 testGen,
2925 opName,
evacha019c96eef2024-02-07 11:21:55 +00002926 shapeList,
Jeremy Johnson708da822023-11-15 16:25:45 +00002927 dtype,
2928 arg_list,
2929 error_name,
2930 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002931
2932 return arg_list
2933
    @staticmethod
    def agRescale(testGen, rng, opName, shapeList, inDtype, error_name=None):
        """Generate RESCALE argument combinations.

        Enumerates output dtype x scale32 x double_round x per_channel,
        skipping combinations that are illegal (or that would not trigger the
        requested ERROR_IF condition), and computes random multiplier/shift
        arrays for each surviving combination.

        Returns a list of (arg_str, args_dict) tuples.
        """
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            # Skip output types that cannot host the requested ERROR_IF case
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition. Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        # Per-channel rescale needs one multiplier/shift per
                        # channel (last dimension), otherwise just one
                        if per_channel:
                            nc = shapeList[0][-1]
                        else:
                            nc = 1

                        in_type_width = gtu.dtypeWidth(inDtype)
                        out_type_width = gtu.dtypeWidth(outDtype)

                        # Calculate scale based on:
                        # scale = a *(2^output_width)/(2^input_width))

                        a = np.float32(rng.random(size=[nc]))
                        scale_arr = a * np.float32(
                            (1 << out_type_width) / (1 << in_type_width)
                        )

                        if scale32:
                            # Cap the scaling at 2^31 - 1 for scale32
                            scale_arr = np.clip(
                                scale_arr, 1.0 / (1 << 31), (1 << 31) - 1
                            )
                        else:
                            # Cap the scaling at 2^15 - 1 for scale16
                            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

                        logger.debug(
                            f"agRescale: {out_type_width} {in_type_width} -> {scale_arr}"
                        )

                        # Convert each float scale into integer multiplier
                        # and shift values
                        multiplier_arr = np.int32(np.zeros(shape=[nc]))
                        shift_arr = np.int32(np.zeros(shape=[nc]))
                        for i in range(nc):
                            (
                                multiplier_arr[i],
                                shift_arr[i],
                            ) = TosaQuantGen.computeMultiplierAndShift(
                                scale_arr[i], scale32
                            )

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                {
                                    "output_dtype": outDtype,
                                    "scale": scale32,
                                    "double_round": double_round,
                                    "per_channel": per_channel,
                                    "multiplier": multiplier_arr,
                                    "shift": shift_arr,
                                },
                            )
                        )

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            shapeList,
            inDtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
3087
3088 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003089 def agMul(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003090 arg_list = []
3091
3092 if dtype is DType.INT32:
3093 for p in range(testGen.args.num_rand_permutations):
3094
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003095 shift = rng.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003096 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003097 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003098 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003099
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003100 arg_list = TosaArgGen._add_data_generators(
3101 testGen,
3102 opName,
evacha019c96eef2024-02-07 11:21:55 +00003103 shapeList,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003104 dtype,
3105 arg_list,
3106 error_name,
3107 )
3108 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003109 return arg_list
3110
3111 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003112 def agArithmeticRightShift(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003113 arg_list = []
3114
Jeremy Johnson587cc842024-02-08 11:45:44 +00003115 for round in (True, False):
3116 args_dict = {
3117 "round": round,
3118 }
3119 arg_list.append((f"round{round}", args_dict))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003120
Jeremy Johnson587cc842024-02-08 11:45:44 +00003121 arg_list = TosaArgGen._add_data_generators(
3122 testGen,
3123 opName,
evacha019c96eef2024-02-07 11:21:55 +00003124 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003125 dtype,
3126 arg_list,
3127 error_name,
3128 )
3129 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003130 return arg_list
3131
Luke Hutton57287132023-02-06 14:54:18 +00003132 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003133 def agFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Luke Hutton57287132023-02-06 14:54:18 +00003134 arg_list = []
3135
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003136 shape = shapeList[0]
3137 dot_products = gtu.product(shape)
3138 ks = 2 * shape[1] * shape[2] # 2*H*W
3139 for inverse in (True, False):
3140 args_dict = {
3141 "dot_products": dot_products,
3142 "shape": shape,
3143 "ks": ks,
3144 "acc_type": dtype,
3145 "inverse": inverse,
3146 }
3147 arg_list.append((f"inverse{inverse}", args_dict))
Luke Hutton57287132023-02-06 14:54:18 +00003148
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003149 arg_list = TosaArgGen._add_data_generators(
3150 testGen,
3151 opName,
evacha019c96eef2024-02-07 11:21:55 +00003152 shapeList,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003153 dtype,
3154 arg_list,
3155 error_name,
3156 )
3157 # Return list of tuples: (arg_str, args_dict)
Luke Hutton57287132023-02-06 14:54:18 +00003158 return arg_list
3159
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003160 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003161 def agRFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003162 arg_list = []
3163
3164 shape = shapeList[0]
3165 dot_products = gtu.product(shape)
3166 ks = shape[1] * shape[2] # H*W
3167 args_dict = {
3168 "dot_products": dot_products,
3169 "shape": shape,
3170 "ks": ks,
3171 "acc_type": dtype,
3172 }
3173 arg_list.append(("", args_dict))
3174
3175 arg_list = TosaArgGen._add_data_generators(
3176 testGen,
3177 opName,
evacha019c96eef2024-02-07 11:21:55 +00003178 shapeList,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003179 dtype,
3180 arg_list,
3181 error_name,
3182 )
3183 # Return list of tuples: (arg_str, args_dict)
3184 return arg_list
3185
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003186 # Helper function for reshape. Gets some factors of a larger number.
3187 @staticmethod
3188 def getFactors(val, start=1):
3189 factors = []
3190
3191 for i in range(start, int(np.sqrt(val)) + 1):
3192 if (val % i) == 0:
3193 factors.append(i)
3194
3195 return factors
3196
    @staticmethod
    def agReshape(testGen, rng, opName, shapeList, dtype, error_name=None):
        """Generate random new shapes for RESHAPE.

        For each requested permutation, picks a random rank and builds a new
        shape with the same total element count as the input by randomly
        combining its factors.  Duplicate shapes are rejected and retried.

        Returns a list of (arg_str, args_dict) tuples where args_dict holds
        the chosen "new_shape".
        """
        arg_list = []

        origShape = shapeList[0]
        totalElements = gtu.product(origShape)
        factors = TosaArgGen.getFactors(totalElements)

        # Find new shapes up to the number of permutations asked for
        # This code is NOT fast. Fortunately, the numbers are fairly small.
        for p in range(testGen.args.num_rand_permutations):
            # Rank from 1 to MAX_TENSOR_RANK
            newRank = rng.randInt(1, (gtu.MAX_TENSOR_RANK + 1))
            if len(factors) < newRank:
                continue

            # escape_counter limits the generation of new shapes to a reasonable time
            for escape_counter in range(100):

                # Generate the new shape of the chosen new rank
                newShape = []
                remainingElements = totalElements
                shuffledFactors = rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # Last dimension takes whatever element count is left over
                newShape.append(remainingElements)

                # Check for duplicates
                duplicate = False
                for name, args_dict in arg_list:
                    if args_dict["new_shape"] == newShape:
                        duplicate = True
                        break

                if not duplicate:
                    outShape = "x".join([str(x) for x in newShape])
                    arg_list.append(
                        (
                            "perm{}_rank{}_out{}".format(p, newRank, outShape),
                            {"new_shape": newShape},
                        )
                    )
                    # Found an output shape for this permutation
                    break

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            shapeList,
            dtype,
            arg_list,
            error_name,
        )

        return arg_list
3258
3259 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003260 def agTranspose(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003261 arg_list = []
3262
3263 ifm_shape = shapeList[0]
3264
3265 if error_name == ErrorIf.IndexOutsideBounds:
3266 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
3267 incorrect_small_index = range(-len(ifm_shape), 0)
3268 permutations = [p for p in itertools.permutations(incorrect_large_index)]
3269 permutations.extend(
3270 [p for p in itertools.permutations(incorrect_small_index)]
3271 )
3272 elif error_name == ErrorIf.IndexUsedTwice:
3273 # Create list with a duplicated index
3274 perm_range = list(range(len(ifm_shape)))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003275 index_choice = rng.choice(range(len(perm_range)))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003276 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
3277 permutations = [p for p in itertools.permutations(perm_range)]
3278
3279 else:
3280 # Get all permutations
3281 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
3282
3283 # Limit to possible permutations from shape dimension or argument setting
3284 limit = min(len(permutations), testGen.args.num_rand_permutations)
3285
3286 # Get random permutation generator that uses all permutations
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003287 random_permutations = rng.permutation(permutations)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003288
3289 # Create list of required amount of permutations
3290 arg_list = [
evacha0198477222024-01-26 12:25:32 +00003291 ("perm{}".format(p), {"perms": random_permutations[p].tolist()})
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003292 for p in range(limit)
3293 ]
evacha0198477222024-01-26 12:25:32 +00003294 # Now add data generator types
3295 arg_list = TosaArgGen._add_data_generators(
3296 testGen,
3297 opName,
evacha019c96eef2024-02-07 11:21:55 +00003298 shapeList,
evacha0198477222024-01-26 12:25:32 +00003299 dtype,
3300 arg_list,
3301 error_name,
3302 )
3303 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003304 return arg_list
3305
3306 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003307 def agSlice(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003308 arg_list = []
3309
3310 ifm_shape = shapeList[0]
3311 rank = len(ifm_shape)
3312
3313 for p in range(testGen.args.num_rand_permutations):
3314 start = []
3315 size = []
3316
3317 valid = True
3318
3319 for i in range(rank):
3320 if ifm_shape[i] > 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003321 start.append(rng.randInt(0, ifm_shape[i]))
3322 size.append(rng.randInt(0, ifm_shape[i] - start[i]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003323
3324 # Invalid slice size?
3325 if size[i] == 0:
3326 valid = False
3327 else:
3328 start.append(0)
3329 size.append(1)
3330
3331 if valid:
3332 # If ERROR_IF test required then incorrect start, size will be returned
3333 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003334 rng, error_name, ifm_shape, start, size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003335 )
evacha017f7d4252024-01-24 12:08:09 +00003336 arg_list.append(("perm{}".format(p), {"start": start, "size": size}))
3337 # Now add data generator types
3338 arg_list = TosaArgGen._add_data_generators(
3339 testGen,
3340 opName,
evacha019c96eef2024-02-07 11:21:55 +00003341 shapeList,
evacha017f7d4252024-01-24 12:08:09 +00003342 dtype,
3343 arg_list,
3344 error_name,
3345 )
3346 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003347 return arg_list
3348
3349 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003350 def agTile(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003351 arg_list = []
3352
3353 ifm_shape = shapeList[0]
3354 rank = len(ifm_shape)
3355
3356 for p in range(testGen.args.num_rand_permutations):
3357
3358 # Pick a few random, but small multiple values
3359 # because otherwise this has a tendency to generate
3360 # enormous tensors
3361 multiples = []
3362 for i in range(rank):
3363 if ifm_shape[i] > 1000:
3364 # Multiple of 1 if ifm_shape dimension is large to reduce
3365 # tensor size
3366 multiples.append(1)
3367 elif max(ifm_shape) > 1000:
3368 multiples.append(2)
3369 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003370 multiples.append(rng.randInt(1, 4))
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003371 arg_list.append(("perm{}".format(p), {"multiples": multiples}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003372
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003373 # Now add data generator types
3374 arg_list = TosaArgGen._add_data_generators(
3375 testGen,
3376 opName,
evacha019c96eef2024-02-07 11:21:55 +00003377 shapeList,
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003378 dtype,
3379 arg_list,
3380 error_name,
3381 )
3382 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003383 return arg_list
3384
3385 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003386 def agResize(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003387 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003388 ifm_shape = shapeList[0]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003389
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003390 def get_aspect_ratio_resize_params():
3391 common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003392 aspect_ratio = rng.choice(common_aspect_ratios)
3393 invert = rng.choice((False, True))
3394 letterbox = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003395
3396 scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
3397 scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
3398 scale_y_d = scale_x_d = 1
3399 offset_x = offset_y = 0
3400
3401 if letterbox:
3402 max_border = scale_y_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003403 border_y = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003404 border_x = 0
3405 else:
3406 # Pillarboxing
3407 border_y = 0
3408 max_border = scale_x_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003409 border_x = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003410
3411 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3412 offset = (offset_y, offset_x)
3413 border = (border_y, border_x)
3414
3415 return scale, offset, border
3416
3417 def get_upscale_downscale_params():
3418 valid_params = False
3419 while not valid_params:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003420 upscale = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003421
3422 # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003423 origin_sampling = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003424
3425 if upscale:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003426 shift = rng.randInt(low=1, high=4)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003427 scale_x_d = scale_y_d = 1
3428 scale_x_n = scale_y_n = (
3429 1 << shift if origin_sampling else 2 << shift
3430 )
3431 border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
3432 offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
3433 else:
3434 scale_x_n = 1
3435 scale_y_n = 1
3436
3437 # Return list of valid scale_*_d values (max value 4) given input dim shape
3438 def get_valid_denom(ifm_dim):
3439 return [x for x in range(1, 5) if ifm_dim % x == 1]
3440
3441 # Generate list of valid downscale values and choose one randomly
3442 valid_scale_y_ds = get_valid_denom(ifm_shape[1])
3443 valid_scale_x_ds = get_valid_denom(ifm_shape[2])
3444
3445 if not valid_scale_y_ds and not valid_scale_x_ds:
3446 # Bad parameters, skip
3447 continue
3448
3449 if not valid_scale_y_ds:
3450 scale_y_d = 1
3451 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003452 scale_y_d = rng.choice(valid_scale_y_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003453
3454 if not valid_scale_x_ds:
3455 scale_x_d = 1
3456 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003457 scale_x_d = rng.choice(valid_scale_x_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003458
3459 border_x = border_y = 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003460 offset_y = rng.randInt(0, 16 * scale_y_n)
3461 offset_x = rng.randInt(0, 16 * scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003462 valid_params = True
3463
3464 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3465 offset = (offset_y, offset_x)
3466 border = (border_y, border_x)
3467 return scale, offset, border
3468
3469 def get_rand_params():
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003470 def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
3471 scale = scale_n / scale_d
3472 if scale > max_scale:
3473 factor = scale / max_scale
3474 new_scale_d = math.ceil(scale_d * factor)
3475 assert scale_n / new_scale_d <= max_scale
3476 scale_d = new_scale_d
3477 return scale_d
3478
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003479 # Scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003480 scale_y_n = rng.randInt(low=1, high=(1 << 11))
3481 scale_x_n = rng.randInt(low=1, high=(1 << 11))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003482
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003483 scale_y_d = rng.randInt(low=1, high=(16 * scale_y_n))
3484 scale_x_d = rng.randInt(low=1, high=(16 * scale_x_n))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003485
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003486 scale_y_d = fix_scale_to_max_scale(
3487 scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3488 )
3489 scale_x_d = fix_scale_to_max_scale(
3490 scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3491 )
3492
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003493 # Offsets and border within the scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003494 offset_y = rng.randInt(low=-scale_y_n, high=(16 * scale_y_n))
3495 offset_x = rng.randInt(low=-scale_x_n, high=(16 * scale_x_n))
3496 border_y = rng.randInt(low=(-16 * scale_y_n), high=scale_y_n)
3497 border_x = rng.randInt(low=(-16 * scale_x_n), high=scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003498
3499 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3500 offset = (offset_y, offset_x)
3501 border = (border_y, border_x)
3502 return scale, offset, border
3503
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003504 def get_level_8k_params():
3505 # Create 64x scale - 64/1 to 2048/32
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003506 scale_d = rng.randInt(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003507 low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
3508 )
3509 scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
3510 # Create half to fifth scaling
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003511 scale_d_alt = rng.randInt(low=2, high=6)
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003512 scale_n_alt = 1
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003513 switch = rng.choice((False, True))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003514 if switch:
3515 scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
3516 else:
3517 scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)
3518
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003519 offset_y = rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
3520 offset_x = rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003521 offset = (offset_y, offset_x)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003522 border_y = rng.choice((-16 * scale[0], 0, scale[0] - 1))
3523 border_x = rng.choice((-16 * scale[2], 0, scale[2] - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003524 border = (border_y, border_x)
3525 return scale, offset, border
3526
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003527 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003528 # Exclude illegal {mode, type} configurations. Pick legal output types
3529 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
3530 outputDTypeList = [DType.INT8]
3531 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
3532 outputDTypeList = [DType.INT16]
3533 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
3534 outputDTypeList = [DType.INT32]
3535 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
3536 outputDTypeList = [DType.INT48]
James Ward8b390432022-08-12 20:48:56 +01003537 elif dtype == DType.FP16:
3538 outputDTypeList = [DType.FP16]
James Ward24dbc422022-10-19 12:20:31 +01003539 elif dtype == DType.BF16:
3540 outputDTypeList = [DType.BF16]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003541 elif dtype == DType.FP32:
3542 outputDTypeList = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00003543 elif dtype == DType.FP8E4M3:
3544 outputDTypeList = [DType.FP8E4M3]
3545 elif dtype == DType.FP8E5M2:
3546 outputDTypeList = [DType.FP8E5M2]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003547 elif error_name == ErrorIf.WrongInputType:
3548 # If an incorrect input type is used then we set a 'correct'
3549 # output type to avoid other errors
3550 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
3551 else:
3552 continue
3553
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003554 arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"
3555
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003556 for outputDType in outputDTypeList:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003557 perm = 0
3558 while perm < testGen.args.num_rand_permutations:
3559 # Random choice of type of params we are testing
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003560 if not testGen.args.level8k:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003561 _rnd_param_fn = rng.choice(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003562 (
3563 get_rand_params,
3564 get_upscale_downscale_params,
3565 get_aspect_ratio_resize_params,
3566 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003567 )
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003568 scale, offset, border = _rnd_param_fn()
3569 else:
3570 scale, offset, border = get_level_8k_params()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003571
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003572 # Expand params for bounds-checking
3573 (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
3574 (offset_y, offset_x) = offset
3575 (border_y, border_x) = border
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003576
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003577 # Make sure output dimensions OH and OW are integers
3578 partial_output_y = (
3579 (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
3580 )
3581 partial_output_x = (
3582 (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
3583 )
3584 if error_name == ErrorIf.ResizeOutputShapeNonInteger:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003585 # Look for non-integer test
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003586 if (
3587 partial_output_y % scale_y_d == 0
3588 and partial_output_x % scale_x_d == 0
3589 ):
3590 # Skip this test as it doesn't produce NonInteger output
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003591 if perm > 0:
3592 perm += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003593 continue
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003594 else:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003595 # Alter the scaling factors to make the output integer
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003596 while partial_output_y % scale_y_d != 0:
3597 scale_y_d -= 1
3598 while partial_output_x % scale_x_d != 0:
3599 scale_x_d -= 1
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003600 # Make sure we are still within max scaling
3601 if (
3602 scale_y_n / scale_y_d
3603 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
3604 scale_x_n / scale_x_d
3605 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
3606 # Skip the test as it is using too large a scaling factor
3607 if perm > 0:
3608 perm += 1
3609 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003610
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003611 output_y = partial_output_y // scale_y_d + 1
3612 output_x = partial_output_x // scale_x_d + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003613
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003614 if (
3615 output_y >= testGen.args.max_resize_output_dim
3616 or output_x >= testGen.args.max_resize_output_dim
3617 ) and error_name is None:
3618 # Skip positive test if output dim will be too high
3619 # Avoid high test latency and OOM issues
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003620 if not testGen.args.level8k or perm > 0:
3621 perm += 1
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003622 continue
3623
3624 if (
3625 output_y <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003626 or output_y >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003627 or output_x <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003628 or output_x >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003629 ):
3630 # Output dimensions out of scope
3631 if error_name is not None and perm > 0:
3632 # As long as we have one ERROR_IF test, don't worry
3633 # about creating all the other permutations
3634 perm += 1
3635 continue
3636
3637 if error_name == ErrorIf.ResizeOutputShapeMismatch and (
3638 (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003639 output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003640 and output_y - scale_y_d < 1
3641 )
3642 or (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003643 output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003644 and output_x - scale_x_d < 1
3645 )
3646 ):
3647 # Can't create a negative test with these params as it
3648 # will create invalid output size
3649 if perm > 0:
3650 perm += 1
3651 continue
3652
3653 scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
3654 offset = [offset_y, offset_x]
3655 border = [border_y, border_x]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003656
3657 # Common for all data types
3658 if error_name is not None:
3659 (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003660 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003661 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003662 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003663 outputDTypeNew,
3664 ) = TosaErrorIfArgGen.eiResizeErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003665 rng,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003666 error_name,
3667 mode,
3668 dtype,
3669 shapeList,
3670 outputDType,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003671 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003672 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003673 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003674 )
3675 else:
3676 outputDTypeNew = outputDType
3677
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003678 arg_to_append = (
3679 arg_str.format(
3680 "N" if mode == ResizeMode.NEAREST else "B",
3681 testGen.typeStr(outputDTypeNew),
3682 scale[0],
3683 scale[1],
3684 scale[2],
3685 scale[3],
3686 offset[0],
3687 offset[1],
3688 border[0],
3689 border[1],
3690 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003691 {
3692 "mode": mode,
3693 "scale": scale,
3694 "offset": offset,
3695 "border": border,
3696 "output_dtype": outputDTypeNew,
3697 },
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003698 )
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003699 if arg_to_append in arg_list:
3700 # Skip already generated test params
3701 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003702
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003703 # Valid permutation
3704 perm += 1
3705 arg_list.append(arg_to_append)
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003706
3707 # Now add data generator types
3708 arg_list = TosaArgGen._add_data_generators(
3709 testGen,
3710 opName,
evacha019c96eef2024-02-07 11:21:55 +00003711 shapeList,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003712 dtype,
3713 arg_list,
3714 error_name,
3715 )
3716 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003717 return arg_list
3718
3719 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003720 def agTable(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003721 arg_list = []
3722
3723 if dtype == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003724 table = np.int32(rng.integers(low=-128, high=128, size=[256])).tolist()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003725 else: # INT16
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003726 table = np.int32(rng.integers(low=-32768, high=32768, size=[513])).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003727 # Make sure all slopes are within REQUIRE min/max 16-bit int
3728 for idx in range(len(table) - 1):
3729 slope = table[idx + 1] - table[idx]
3730 # Alter the next table entry to force the slope to be ok
3731 if slope > 32767:
3732 table[idx + 1] -= slope - 32767
3733 if slope < -32768:
3734 table[idx + 1] -= slope + 32768
3735 slope = table[idx + 1] - table[idx]
3736 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003737 arg_list.append(
3738 (
3739 "",
Jeremy Johnson587cc842024-02-08 11:45:44 +00003740 {"table": table},
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003741 )
3742 )
Jeremy Johnson587cc842024-02-08 11:45:44 +00003743 # Now add data generator types
3744 arg_list = TosaArgGen._add_data_generators(
3745 testGen,
3746 opName,
evacha019c96eef2024-02-07 11:21:55 +00003747 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003748 dtype,
3749 arg_list,
3750 error_name,
3751 )
3752 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003753 return arg_list
3754
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003755 def agCondIf(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003756 # CondIf generates the condition values here.
3757 # Convert to tensors in the build function, along with the
3758 # then and else blocks
3759 arg_list = []
3760
3761 for c in [False, True]:
Jeremy Johnson587cc842024-02-08 11:45:44 +00003762 arg_list.append(("cond{}".format(int(c)), {"condition": c}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003763
Jeremy Johnson587cc842024-02-08 11:45:44 +00003764 # Now add data generator types
3765 arg_list = TosaArgGen._add_data_generators(
3766 testGen,
3767 opName,
evacha019c96eef2024-02-07 11:21:55 +00003768 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003769 dtype,
3770 arg_list,
3771 error_name,
3772 )
3773 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003774 return arg_list
3775
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003776 def agWhileLoop(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003777 # While loop: 0 iterations, 1, more than 1
3778 arg_list = []
3779
Jeremy Johnson587cc842024-02-08 11:45:44 +00003780 for iterations in [0, 1, 4]:
3781 arg_list.append(("iter{}".format(iterations), {"iterations": iterations}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003782
Jeremy Johnson587cc842024-02-08 11:45:44 +00003783 # Now add data generator types
3784 arg_list = TosaArgGen._add_data_generators(
3785 testGen,
3786 opName,
evacha019c96eef2024-02-07 11:21:55 +00003787 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003788 dtype,
3789 arg_list,
3790 error_name,
3791 )
3792 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003793 return arg_list