blob: 26dd6f948a0dbf2a2b915f0ce73dd3ce37e425d8 [file] [log] [blame]
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
Jeremy Johnsonaf090182024-02-13 18:25:39 +00004import logging
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01005import math
6
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
# DTypeNames, DType, Op and ResizeMode are convenience variables to the
# flatc-generated types that should be enums, but aren't
18
# Configure logging once at import time; all messages from this module go
# through a single named logger shared by the test build tooling.
logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")
21
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010022
23class TosaQuantGen:
24 """QuantizedInfo random generator helper functions.
25
26 Specify with 'qgen': in the operator defintion.
27 """
28
29 def __init__(self):
30 pass
31
32 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010033 def getZeroPoint(rng, zeropoint, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010034
35 if dtype == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010036 if zeropoint is not None:
37 return min(127, max(-128, zeropoint))
38 return rng.randInt(-128, 128)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010039 elif dtype == DType.UINT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010040 if zeropoint is not None:
41 return min(255, max(0, zeropoint))
42 return rng.randInt(0, 256)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010043 elif error_name in [
44 ErrorIf.InputZeroPointNotZero,
45 ErrorIf.WeightZeroPointNotZero,
46 ErrorIf.OutputZeroPointNotZero,
47 ]:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010048 zero_point = rng.randInt(-128, 128)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010049 if zero_point == 0:
50 zero_point = 1
51 return zero_point
52 return 0
53
54 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010055 def qgUnary(rng, zeropoint, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010056 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000057 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010058 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
59 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000060 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010061 elif error_name == ErrorIf.OutputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000062 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010063 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
64 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000065 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010066 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000067 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010068 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
69 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000070 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010071 return qinfo
72
73 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010074 def qgConv(rng, zeropoint, op, dtype_or_dtypeList, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010075 if isinstance(dtype_or_dtypeList, list):
76 # a list of [input, weights, accumulator] dtypes
77 dtypeList = dtype_or_dtypeList
78 else:
79 # an int, [input, weights, accumulator] dtypes are the same
80 dtypeList = [dtype_or_dtypeList] * 3
81
82 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000083 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010084 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0], error_name),
85 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1]),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000086 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010087 elif error_name == ErrorIf.WeightZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000088 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010089 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0]),
90 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1], error_name),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000091 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010092 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000093 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010094 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0]),
95 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1]),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000096 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010097 return qinfo
98
99 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100100 def qgMatmul(rng, zeropoint, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100101 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000102 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100103 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
104 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000105 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100106 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000107 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100108 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
109 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000110 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100111 return qinfo
112
113 @staticmethod
114 def computeMultiplierAndShift(scaleFp, scale32):
115 # Derived from computeMultiplierAndShiftTosaScale32
116 # Provide a floating-point scaling factor and the scale32 parameter
117 # to compute the multiplier and shift
118
119 if scale32:
120 scaleBits = 31
121 else:
122 scaleBits = 15
123
124 m, shift = math.frexp(scaleFp)
125
126 if scaleFp < 0.0:
127 m = -m
128
129 multiplier = round(m * (1 << scaleBits))
130 assert multiplier <= (1 << scaleBits)
131
132 if multiplier == (1 << scaleBits):
133 multiplier = multiplier // 2
134 shift = shift + 1
135
136 shift = (-shift) + scaleBits
Jeremy Johnsonaf090182024-02-13 18:25:39 +0000137 logger.debug(
138 f"computeMultiplierAndShift: scalefp={scaleFp} scaleBits={scaleBits} m={m} mult={multiplier} shift={shift}"
139 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100140
141 # Adjust multiplier such that shift is in allowed value range.
142 if shift == 0:
143 multiplier = multiplier // 4
144 shift = shift + 2
145 elif shift == 1:
146 multiplier = multiplier // 2
147 shift = shift + 1
148 elif shift == 63:
149 multiplier = multiplier * 2
150 shift = shift - 1
151
152 assert multiplier <= (1 << scaleBits)
153 assert shift >= 2 and shift <= 62
154
155 return multiplier, shift
156
157
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, rng, op, rank, error_name=None):
        # Same random shape for every operand of the op.
        pl, const = op["operands"]
        shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rng, rank + rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rng, rank + rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rng, rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        # Each operand gets its own copy of the same NHWC shape
        return [shape.copy() for _ in range(pl + const)]

    @staticmethod
    def tgGather(testGen, rng, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_shape = testGen.makeShape(rng, rank)
        values_shape = testGen.constrictBatchSize(values_shape)

        # Indices are [N, W] where N matches the values batch
        N = values_shape[0]
        W = testGen.makeDimension(rng)
        indices_shape = [N, W]

        return [values_shape, indices_shape]

    @staticmethod
    def tgScatter(testGen, rng, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rng, rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = rng.randInt(W_min, W_max) if W_min < W_max else W_min

        input_shape = [N, W, C]

        # [values_in, indices, input]
        return [values_in_shape, [N, W], input_shape]

    @staticmethod
    def _get_broadcast_shapes(testGen, rng, num_shapes, rank, error_name=None):
        if rank == 0:
            # No broadcasting possible for rank 0
            return [[]] * num_shapes

        shape = testGen.makeShape(rng, rank)
        # Do not broadcast for some tests
        if error_name is None and rng.randInt(high=100) < 10:
            return [shape] * num_shapes
        shape_list = []

        # Choose any one of the inputs to broadcast
        # Note for ERRORS: Simplifies OutputShaper code if we don't change first shape
        bcast_idx = rng.randInt(0 if error_name is None else 1, num_shapes)
        fuzz_idx = rng.randInt(0, rank)

        for i in range(num_shapes):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(rng, extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]
        return TosaTensorGen._get_broadcast_shapes(
            testGen, rng, pl + const, rank, error_name
        )

    @staticmethod
    def tgMul(testGen, rng, op, rank, error_name=None):
        # Get broadcast shapes for the first 2 inputs as the 3rd is shift
        shape_list = TosaTensorGen._get_broadcast_shapes(
            testGen, rng, 2, rank, error_name
        )
        # Add a single dimension tensor for shift
        shape_list.append([1])
        return shape_list

    @staticmethod
    def tgConv2D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension(rng)

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC or 1 if broadcastable
        if op.get("broadcastable_bias") and rng.choice([True, False]):
            ofm_depth = 1
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension(rng)

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension(rng)

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension(rng) % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rng, rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Two inputs (real and imaginary) which normally share the same shape
        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rng, rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        # Random output channels within the configured shape range
        filter_oc = rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]
        shape = testGen.makeShape(rng, rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = rng.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                # Produce a deliberately wrong-rank shape for this input
                remove = rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(rng, shapeList, axis, error_name=None):
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
656
657
class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each tensor in each test."""

    def __init__(self):
        pass

    class TVGInfo:
        """Enhanced tensor values information including data gen dict."""

        def __init__(self, tensorList, dataGenDict):
            # List of serialized tensors for the test
            self.tensorList = tensorList
            # Dictionary describing how the tensor data is to be generated
            self.dataGenDict = dataGenDict

    # Default high value for random numbers
    # (largest finite value representable in each floating point format)
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
        DType.FP8E4M3: 448,
        DType.FP8E5M2: 57344,
    }

    # Default lowest normal values for random numbers
    # (smallest positive normal value of each floating point format)
    TVG_FLOAT_LOW_VALUE = {
        DType.FP32: np.exp2(-126),
        DType.FP16: np.exp2(-14),
        DType.BF16: np.exp2(-126),
        DType.FP8E4M3: np.exp2(-9),
        DType.FP8E5M2: np.exp2(-16),
    }

    @staticmethod
    def _get_data_range(rng, dtype, highValueLookup, lowValueLookup=None):
        # Return a tuple of (low,high) data range values for the given data
        # type using a combination of per operator table limits, data limits
        # and user supplied ranges for FP numbers
        if dtype in highValueLookup:
            type_range = rng.dTypeRange(dtype, high_inclusive=True)
            high_val = highValueLookup[dtype]
            if lowValueLookup is not None and dtype in lowValueLookup:
                low_val = lowValueLookup[dtype]
            else:
                low_val = -high_val
            # Set the values to something that won't produce infinity whilst
            # respecting the default ranges if more/less than the low/high
            # values
            data_range = (
                max(low_val, type_range[0]),
                min(high_val, type_range[1]),
            )
            if data_range[0] > data_range[1]:
                # Invalid data range from low to high created due to user
                # constraints revert to using internal ranges as they are
                # known to work
                logger.info(
                    f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
                )
                data_range = (low_val, high_val)
            return data_range
        return None
718
    @staticmethod
    def tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Default tensor value generator.

        Creates the op's placeholder/constant tensors and, when the op
        supports the data generator library, builds the meta-data that
        describes how each tensor's contents should be produced.

        Returns a TVGInfo of (serializer tensors, meta-data dict) - the
        meta-data is None when the in-line fallback path was used.
        """
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
        if "p_count" in argsDict:
            # Override for operators like CONCAT
            pCount = argsDict["p_count"]
            cCount = argsDict["c_count"]
        assert pCount + cCount == len(
            shapeList
        ), "Placeholders & Constant tensors must match shapes list"

        tens_ser_list = []

        if (
            error_name is not None
            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
            or "data_gen" not in testGen.TOSA_OP_LIST[opName]
        ):
            # Fall back to internal data gen when dealing with unsupported types or ops
            data_range = argsDict["data_range"] if "data_range" in argsDict else None
            for idx, info in enumerate(zip(shapeList, dtypeList)):
                roundMode = False
                shape, dtype = info
                if "data_range_list" in argsDict:
                    # Per-tensor range (and optional rounding) overrides
                    data_range = argsDict["data_range_list"][idx]["range"]
                    roundMode = (
                        "round" in argsDict["data_range_list"][idx]
                        and argsDict["data_range_list"][idx]["round"] is True
                    )
                if data_range is not None and dtype not in (
                    DType.FP16,
                    DType.FP32,
                    DType.BF16,
                    DType.FP8E4M3,
                    DType.FP8E5M2,
                ):
                    # Change from inclusive to exclusive range
                    data_range = (data_range[0], data_range[1] + 1)

                # Ignore lazy data gen option and create data array using any range limits
                if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                    # Use the pre-generated data, converted to the matching numpy type
                    if dtype == DType.SHAPE:
                        arr = np.int64(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT8:
                        arr = np.int8(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT16:
                        arr = np.int16(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT32:
                        arr = np.int32(argsDict["fixed_data"][idx])
                    else:
                        assert False, "Unsupported fixed_data type"
                else:
                    arr = rng.randTensor(shape, dtype, data_range)
                if roundMode:
                    arr = np.round(arr)
                # First pCount tensors are variable inputs, the rest are constants
                if idx < pCount:
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
        dg_type = argsDict["dg_type"]
        tens_data = {
            "version": "0.1",
            "tensors": {},
        }
        dg_tens_meta = tens_data["tensors"]

        # Random start index shared by all of the op's tensors when the
        # FP_SPECIAL data gen type is selected
        fp_special_info = {}
        fp_special_info["start_idx"] = int(rng.randInt())

        for idx, shape in enumerate(shapeList):

            tens_meta = {}
            if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                tens_meta["generator"] = gtu.DataGenType(
                    gtu.DataGenType.FIXED_DATA
                ).name
            else:
                tens_meta["generator"] = gtu.DataGenType(dg_type).name

            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()

            if testGen.args.random_const_inputs:
                # Choose type of tensor biased by defaults
                percentage = rng.randInt(0, 100)
                variable = (idx < pCount and percentage < 70) or (
                    idx >= pCount and percentage >= 70
                )
            else:
                # Use default set up of constants versus inputs for the op
                variable = idx < pCount

            if variable:
                tens_meta["input_type"] = "VARIABLE"
            else:
                tens_meta["input_type"] = "CONSTANT"

            # Add the data gen type specific meta data
            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                if (
                    tens_meta["generator"]
                    == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name
                ):
                    info["data"] = [int(i) for i in argsDict["fixed_data"][idx]]
                    tens_meta["fixed_data_info"] = info
                else:
                    info["rng_seed"] = rng.seed

                    # Per-tensor range list takes precedence over the op-wide range
                    data_range = None
                    if "data_range_list" in argsDict:
                        data_range = argsDict["data_range_list"][idx]["range"]
                        if "round" in argsDict["data_range_list"][idx]:
                            info["round"] = argsDict["data_range_list"][idx]["round"]
                    elif "data_range" in argsDict:
                        data_range = argsDict["data_range"]

                    if data_range is None:
                        data_range = rng.dTypeRange(dtypeList[idx], high_inclusive=True)
                    info["range"] = [str(v) for v in data_range]
                    tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = int(argsDict["ks"])
                if "acc_type" in argsDict:
                    # Convert type number into JSON name
                    info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
                        "json"
                    ]
                if "kernel" in argsDict:
                    info["kernel"] = [int(k) for k in argsDict["kernel"]]
                if "axis" in argsDict:
                    info["axis"] = int(argsDict["axis"])
                tens_meta["dot_product_info"] = info
            elif dg_type == gtu.DataGenType.FULL_RANGE:
                info = {}
                info["start_val"] = int(
                    rng.randInt(0, gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["fullset"])
                )
                tens_meta["full_range_info"] = info
            elif dg_type == gtu.DataGenType.FP_SPECIAL:
                tens_meta["fp_special_info"] = fp_special_info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"

            # Using the finished generate config meta data - generate the data if
            # needed and assign a tensor name from the serializer

            # Need to generate data when not lazy or for the bias tensor as we need
            # to work out if the bias data is non-zero for compliance
            if not testGen.args.lazy_data_gen or (
                idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
            ):
                # Give this tensor a temporary name until we get one from the serializer
                temp_name = f"placeholder_{idx}"
                dg_tens_meta[temp_name] = tens_meta
                # Create data now using the temporary name to access meta details
                data = testGen.dgl.get_tensor_data(temp_name, tens_data)
                if tens_meta["data_type"] == "SHAPE":
                    # Tensor type SHAPE and Numpy file type must be the same
                    data = np.int64(data)
                # Remove the item as we will give it the correct name later
                del dg_tens_meta[temp_name]

            if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
                # The KS value used by compliance verification is altered when the
                # bias data is non-zero
                if max(abs(data)) > 0.0:
                    argsDict["ksb"] = argsDict["ks"] + 1

            if testGen.args.lazy_data_gen:
                # Data will be produced later by the data generator library
                data = None

            if variable:
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
            else:
                tens = testGen.ser.addConst(shape, dtypeList[idx], data)

            tens_ser_list.append(tens)
            # Add the meta data to the list using the serializer tensor name
            dg_tens_meta[tens.name] = tens_meta

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
912
913 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100914 def tvgNegate(
915 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
916 ):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100917 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000918 # Integer test
919 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100920 pCount, cCount = op["operands"]
921 assert (
922 pCount == 1 and cCount == 0
923 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100924 # Must create tensors with values within accumulator (int32) negatable
925 # range
926 max_val = (1 << 31) - 1
927 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100928 arr = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100929 rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100930 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000931 tens_ser_list = []
932 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100933 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
934 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000935 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100936 else:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000937 # ERROR_IF or floating point test
938 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100939 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100940 )
941
    # Set the ADD/SUB data range to half the largest value to avoid infinities
    # NOTE(review): no FP8 entries - those types fall back to the default
    # ranges as _get_data_range returns None for dtypes missing from the table
    TVG_FLOAT_HIGH_VALUE_ADDSUB = {
        DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
        DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
        DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
    }
948
    @staticmethod
    def tvgAddSub(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Tensor value generator for ADD/SUB (and ADD_SHAPE/SUB_SHAPE).

        For int32/shape tests the second operand is adjusted so that the
        result cannot saturate int32; other types go through the default
        lazy generator with a halved data range.
        """
        if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
            # Make sure the integer operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            tens_ser_list = []
            add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
            data_range = None  # Use default
            if op["op"] in (Op.ADD_SHAPE, Op.SUB_SHAPE):
                data_range = testGen.args.tensor_shape_range
            a_arr = rng.randTensor(shapeList[0], dtypeList[0], data_range)
            b_arr = rng.randTensor(shapeList[1], dtypeList[1], data_range)
            # Compute the result in int64 so any overflow can be observed
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            data_range = TosaTensorValuesGen._get_data_range(
                rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
            )
            if data_range:
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )
1030
1031 @staticmethod
1032 def tvgCondIfWhileLoop(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001033 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001034 ):
1035 if dtypeList[0] in (
1036 DType.INT32,
1037 DType.INT16,
1038 DType.INT8,
1039 ):
1040 # Limit input tensors with cond_if_binary or while_loop to stop
1041 # saturation of add/sub ops with int32 and keep all logical shift
1042 # values between 0 to 31 for int16 or int8
Jeremy Johnson587cc842024-02-08 11:45:44 +00001043 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001044 pCount, cCount = op["operands"]
1045 pRemain = pCount
Jeremy Johnson587cc842024-02-08 11:45:44 +00001046 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001047 for idx, shape in enumerate(shapeList[:]):
1048 if dtypeList[0] == DType.INT32:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001049 arr = rng.randTensor(shapeList[idx], DType.INT16)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001050 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001051 arr = np.int32(rng.integers(low=0, high=32, size=shapeList[idx]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001052 if pRemain > 0:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001053 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001054 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
1055 )
1056 pRemain -= 1
1057 else:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001058 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001059 testGen.ser.addConst(shape, dtypeList[idx], arr)
1060 )
1061
Jeremy Johnson587cc842024-02-08 11:45:44 +00001062 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001063 else:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001064 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001065 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001066 )
1067
1068 @staticmethod
1069 def tvgArithmeticRightShift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001070 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001071 ):
Jeremy Johnson587cc842024-02-08 11:45:44 +00001072 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001073 pCount, cCount = op["operands"]
1074 # Force value of operand[1] to be within [0, num_bits]
1075 assert (
1076 pCount == 2 and cCount == 0
1077 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
1078
Jeremy Johnson587cc842024-02-08 11:45:44 +00001079 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001080 for idx, shape in enumerate(shapeList[:]):
1081 if idx == 1:
1082 if dtypeList[idx] == DType.INT8:
Jeremy Johnsone0ded592024-04-15 11:21:32 +01001083 arr = np.int8(rng.integers(low=0, high=8, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001084 elif dtypeList[idx] == DType.INT16:
Jeremy Johnsone0ded592024-04-15 11:21:32 +01001085 arr = np.int16(rng.integers(low=0, high=16, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001086 elif dtypeList[idx] == DType.INT32:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001087 arr = np.int32(rng.integers(low=0, high=32, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001088 elif error_name == ErrorIf.WrongInputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001089 arr = np.int32(rng.integers(low=0, high=8, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001090 else:
1091 raise Exception("OpArithmeticRightShift: invalid input dtype")
1092 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001093 arr = rng.randTensor(shape, dtypeList[idx])
Jeremy Johnson587cc842024-02-08 11:45:44 +00001094 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001095
Jeremy Johnson587cc842024-02-08 11:45:44 +00001096 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001097
1098 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001099 def tvgReshape(
1100 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1101 ):
Won Jeon64e4bfe2024-01-18 06:31:55 +00001102 dtypeList[1] = DType.SHAPE
1103 shapeList[1] = [len(argsDict["new_shape"])]
1104 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1105 argsDict["fixed_data"] = [None, argsDict["new_shape"]]
1106
1107 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001108 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Won Jeon64e4bfe2024-01-18 06:31:55 +00001109 )
1110
1111 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001112 def tvgRescale(
1113 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1114 ):
Tai Ly6e1e2bc2024-03-01 20:59:32 +00001115 scale32 = argsDict["scale"]
1116 multiplier_arr = argsDict["multiplier"]
1117 shift_arr = argsDict["shift"]
1118
1119 if scale32:
1120 dtypeList[1] = DType.INT32
1121 else:
1122 dtypeList[1] = DType.INT16
1123 shapeList[1] = [len(multiplier_arr)]
1124 dtypeList[2] = DType.INT8
1125 shapeList[2] = [len(shift_arr)]
1126 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1127 argsDict["fixed_data"] = [None, multiplier_arr, shift_arr]
1128
1129 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001130 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Tai Ly6e1e2bc2024-03-01 20:59:32 +00001131 )
1132
1133 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001134 def tvgPad(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Tai Lye095da72024-01-25 22:00:18 +00001135 # argsDict["pad"] is 2D array, need to flatten it to get list of values
1136 pad_values = argsDict["pad"].flatten()
1137 dtypeList[1] = DType.SHAPE
1138 shapeList[1] = [len(pad_values)]
1139 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1140 argsDict["fixed_data"] = [None, pad_values]
1141
1142 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001143 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Tai Lye095da72024-01-25 22:00:18 +00001144 )
1145
1146 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001147 def tvgSlice(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
TatWai Chongf15bad82024-01-31 21:33:27 -08001148 dtypeList[1] = DType.SHAPE
1149 shapeList[1] = [len(argsDict["start"])]
1150 dtypeList[2] = DType.SHAPE
1151 shapeList[2] = [len(argsDict["size"])]
1152 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1153 argsDict["fixed_data"] = [None, argsDict["start"], argsDict["size"]]
1154
1155 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001156 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
TatWai Chongf15bad82024-01-31 21:33:27 -08001157 )
1158
1159 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001160 def tvgTile(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Won Jeon64e4bfe2024-01-18 06:31:55 +00001161 dtypeList[1] = DType.SHAPE
1162 shapeList[1] = [len(argsDict["multiples"])]
1163 argsDict["fixed_data"] = [None, argsDict["multiples"]]
1164
1165 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001166 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Won Jeon64e4bfe2024-01-18 06:31:55 +00001167 )
1168
    @staticmethod
    def tvgSelect(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        # Tensor value generator for SELECT - only the condition input
        # needs special handling, the remaining inputs keep their types
        # Set datatype of condition tensor to boolean
        dtypeList[0] = DType.BOOL

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )
1179
1180 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001181 def tvgIntDiv(
1182 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1183 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001184 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001185 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001186 pCount, cCount = op["operands"]
1187 assert (
1188 pCount == 2 and cCount == 0
1189 ), "Op.INTDIV must have 2 placeholders, 0 consts"
1190
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001191 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001192
1193 # Two invalid cases for Op.INTDIV:
1194 # 1. divisor == 0
1195 # 2. dividend == -(1<<31) and divisor == -1
1196 while True:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001197 dividend_arr = rng.randTensor(shapeList[0], dtypeList[0])
1198 divisor_arr = rng.randTensor(shapeList[1], dtypeList[1])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001199
1200 if (divisor_arr == 0).any():
1201 continue
1202
1203 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
1204 continue
1205
1206 break
1207
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001208 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001209 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1210 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001211 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001212 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1213 )
1214
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001215 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001216 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001217 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001218 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001219 )
1220
    # Set the MUL data range to the square root of the largest value
    # to avoid infinities
    # (operand_max * operand_max == type_max, so the product cannot overflow)
    TVG_FLOAT_HIGH_VALUE_MUL = {
        DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
1228
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001229 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001230 def tvgMul(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001231 if error_name is not None or dtypeList[0] in (
1232 DType.FP16,
1233 DType.BF16,
1234 DType.FP32,
1235 ):
1236 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001237 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001238 rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001239 )
1240 if data_range:
1241 argsDict["data_range"] = data_range
1242
Jeremy Johnson0a042992024-02-28 13:20:05 +00001243 if dtypeList[0] != DType.SHAPE:
1244 # Need to supply shift tensor for MUL (not needed for MUL_SHAPE)
1245 dtypeList[2] = DType.INT8
1246 shapeList[2] = [1]
1247 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1248 argsDict["fixed_data"] = [None, None, [argsDict["shift"]]]
1249
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001250 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001251 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001252 )
1253 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001254 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001255 pCount, cCount = op["operands"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001256
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001257 tens_ser_list = []
1258
1259 # Make sure multiply result in int32 range
Won Jeon74342e52024-01-09 00:34:40 +00001260 if dtypeList[0] == DType.SHAPE:
1261 shift = 0
1262 else:
1263 shift = argsDict["shift"]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001264 if dtypeList[0] == DType.INT8:
1265 num_bits = 8
1266 elif dtypeList[0] == DType.INT16:
1267 num_bits = 16
Won Jeon74342e52024-01-09 00:34:40 +00001268 elif dtypeList[0] in (DType.INT32, DType.SHAPE):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001269 num_bits = 32
1270 elif error_name == ErrorIf.WrongInputType:
1271 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001272 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001273 raise Exception(
1274 f"OpMul: invalid input dtype {gtu.DTYPE_ATTRIBUTES[dtypeList[0]]['str']}"
1275 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001276
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001277 for idx, shape in enumerate(shapeList[:]):
Won Jeon74342e52024-01-09 00:34:40 +00001278 if dtypeList[idx] == DType.SHAPE:
1279 low = testGen.args.tensor_shape_range[0]
1280 high = testGen.args.tensor_shape_range[1]
1281 else:
1282 low = -(2 ** (num_bits - 1))
1283 high = (2 ** (num_bits - 1)) - 1
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001284
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001285 a_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[0]))
1286 b_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[1]))
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001287
1288 i = 0
1289 while True:
1290
1291 a_arr_64 = a_arr.astype(np.int64)
1292 b_arr_64 = b_arr.astype(np.int64)
1293
1294 if shift > 0:
1295 rounding = 1 << (shift - 1)
1296 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001297 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001298 result_arr = a_arr_64 * b_arr_64
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001299
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001300 if (result_arr > -(2**31)).all() and (
1301 result_arr <= ((2**31) - 1)
1302 ).all():
1303 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001304
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001305 i = i + 1
1306 a_arr = a_arr // 2
1307 b_arr = b_arr // 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001308
Won Jeon74342e52024-01-09 00:34:40 +00001309 if dtypeList[0] == DType.SHAPE:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001310 # MUL_SHAPE with 2 inputs
Won Jeon74342e52024-01-09 00:34:40 +00001311 tens_ser_list.append(
1312 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
1313 )
1314 tens_ser_list.append(
1315 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
1316 )
1317 else:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001318 # MUL with 3 inputs (3rd is shift)
Won Jeon74342e52024-01-09 00:34:40 +00001319 tens_ser_list.append(
Jeremy Johnson18a379d2024-03-28 15:53:21 +00001320 testGen.ser.addPlaceholder(
1321 shapeList[0], dtypeList[0], a_arr.astype(np.int32)
1322 )
Won Jeon74342e52024-01-09 00:34:40 +00001323 )
1324 tens_ser_list.append(
Jeremy Johnson18a379d2024-03-28 15:53:21 +00001325 testGen.ser.addPlaceholder(
1326 shapeList[1], dtypeList[1], b_arr.astype(np.int32)
1327 )
Won Jeon74342e52024-01-09 00:34:40 +00001328 )
Jeremy Johnson0a042992024-02-28 13:20:05 +00001329 tens_ser_list.append(
1330 testGen.ser.addPlaceholder([1], DType.INT8, np.int8([shift]))
1331 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001332
1333 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001334
1335 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001336 def tvgConcat(
1337 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1338 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001339 count = len(shapeList) - testGen.args.num_const_inputs_concat
1340 if count < 1:
1341 count = 1
1342 if testGen.args.num_const_inputs_concat == 0:
1343 count = len(shapeList)
1344
Won Jeon74342e52024-01-09 00:34:40 +00001345 op = testGen.TOSA_OP_LIST[opName]
1346 if op["op"] == Op.CONCAT_SHAPE:
1347 # Set the axis to 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001348 shapeList = TosaTensorGen.tgConcatConstInput(rng, shapeList, 0, error_name)
Won Jeon74342e52024-01-09 00:34:40 +00001349 else:
1350 shapeList = TosaTensorGen.tgConcatConstInput(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001351 rng, shapeList, argsDict["axis"], error_name
Won Jeon74342e52024-01-09 00:34:40 +00001352 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001353
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001354 # Override default pCount/cCount for operator
1355 argsDict["p_count"] = count
1356 argsDict["c_count"] = len(shapeList) - count
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001357
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001358 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001359 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001360 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001361
1362 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001363 def tvgLogicalShift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001364 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001365 ):
1366 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001367 pCount, cCount = op["operands"]
1368 assert (
1369 pCount == 2 and cCount == 0
1370 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001371 values_arr = rng.randTensor(shapeList[0], dtypeList[0])
1372 shift_arr = np.int32(rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001373 tens_ser_list = []
1374 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001375 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1376 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001377 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001378 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1379 )
1380
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001381 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001382
    @staticmethod
    def tvgEqual(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Set up tensor values for EQUAL.

        For integer types (not handled by the compliance data generator),
        random values would almost never match, so a number of positions are
        forced to hold equal values in both inputs.  ERROR_IF and floating
        point tests fall through to the lazy default generator."""
        if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
            # Integer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"

            a_arr = rng.randTensor(shapeList[0], dtypeList[0])
            b_arr = rng.randTensor(shapeList[1], dtypeList[1])

            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        rng.integers(0, max(shapeList[0][axis], shapeList[1][axis]))
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                # Copy one value across so the comparison has a match here
                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )
1427
1428 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001429 def tvgReduceSum(
1430 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1431 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001432 dtype = dtypeList[0]
1433 if dtype == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001434 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001435 pCount, cCount = op["operands"]
1436 assert (
1437 pCount == 1 and cCount == 0
1438 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
1439 # Limit values so that the sum cannot exceed the range of an int32 during
1440 # summation of any axis
1441 range_val = int((1 << 31) / max(shapeList[0]))
1442 values_arr = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001443 rng.integers(low=-range_val, high=range_val, size=shapeList[0])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001444 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001445 tens_ser_list = []
1446 tens_ser_list.append(
Jeremy Johnson30476252023-11-20 16:15:30 +00001447 testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001448 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001449 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001450 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001451 # ERROR_IF or dot product floating point test
Jeremy Johnson30476252023-11-20 16:15:30 +00001452 if (
1453 error_name is None
1454 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
1455 ):
1456 # Limit ranges for (non error & non compliance) tests by using
1457 # values that can be summed on any axis to not hit infinity
1458 highval_lookup = {
1459 dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
1460 / max(shapeList[0])
1461 }
1462 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001463 rng, dtype, highval_lookup
Jeremy Johnson30476252023-11-20 16:15:30 +00001464 )
1465 assert data_range is not None
1466 argsDict["data_range"] = data_range
1467
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001468 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001469 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001470 )
1471
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001472 @staticmethod
1473 def tvgReduceProduct(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001474 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001475 ):
1476 dtype = dtypeList[0]
1477 if error_name is None:
1478 # Limit ranges for (non error) tests by using
1479 # values that can be multiplied on any axis to not hit infinity
1480 highval_lookup = {
1481 dtype: math.pow(
1482 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
1483 1 / max(shapeList[0]),
1484 )
1485 }
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001486 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001487 assert data_range is not None
1488 argsDict["data_range"] = data_range
1489
1490 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001491 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001492 )
1493
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001494 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001495 def tvgResize(
1496 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1497 ):
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001498 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001499 rng,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001500 dtypeList[0],
1501 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1502 )
1503 if data_range:
1504 argsDict["data_range"] = data_range
1505 # Needed for compliance
1506 argsDict["max_abs_value"] = data_range[1]
1507
Tai Lyc5c2a7e2024-02-22 23:26:28 +00001508 scale_values = argsDict["scale"]
1509 offset_values = argsDict["offset"]
1510 border_values = argsDict["border"]
1511 dtypeList[1] = DType.SHAPE
1512 dtypeList[2] = DType.SHAPE
1513 dtypeList[3] = DType.SHAPE
1514 shapeList[1] = [len(scale_values)]
1515 shapeList[2] = [len(offset_values)]
1516 shapeList[3] = [len(border_values)]
1517 argsDict["fixed_data"] = [None, scale_values, offset_values, border_values]
1518
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001519 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001520 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001521 )
1522
Jeremy Johnson30476252023-11-20 16:15:30 +00001523 # Set the POW exponent high data range
1524 TVG_FLOAT_HIGH_VALUE_POW_EXP = {
1525 DType.FP32: 10.0,
1526 DType.FP16: 10.0,
1527 DType.BF16: 10.0,
1528 }
1529 # POW highest base value (within a safe margin of error) that can be raised
1530 # to +ve exponent that doesn't become Infinity
1531 TVG_FLOAT_HIGH_VALUE_POW_BASE = {
1532 DType.FP32: math.floor(
1533 math.pow(
1534 TVG_FLOAT_HIGH_VALUE[DType.FP32],
1535 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1536 )
1537 ),
1538 DType.FP16: math.floor(
1539 math.pow(
1540 TVG_FLOAT_HIGH_VALUE[DType.FP16],
1541 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1542 )
1543 ),
1544 DType.BF16: math.floor(
1545 math.pow(
1546 TVG_FLOAT_HIGH_VALUE[DType.BF16],
1547 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1548 )
1549 ),
1550 }
1551 # POW lowest base value (within a safe margin of error) that can be raised
1552 # to -ve exponent that doesn't become Infinity
1553 TVG_FLOAT_LOW_VALUE_POW_BASE = {
1554 DType.FP32: math.ceil(
1555 math.pow(
1556 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
1557 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1558 )
1559 * 1000
1560 )
1561 / 1000,
1562 DType.FP16: math.ceil(
1563 math.pow(
1564 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
1565 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1566 )
1567 * 1000
1568 )
1569 / 1000,
1570 DType.BF16: math.ceil(
1571 math.pow(
1572 1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
1573 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1574 )
1575 * 1000
1576 )
1577 / 1000,
1578 }
1579
    @staticmethod
    def tvgPow(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Set up tensor values for POW using three test sets.

        Set 0: positive base with fractional exponent
        Set 1: positive base with integer exponent
        Set 2: negative base with integer exponent
        Base and exponent ranges are limited so results stay finite."""
        if error_name is not None:
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )
        dtype = dtypeList[0]
        # Different ranges for POW
        test_set = argsDict["s"]
        if test_set == 0:
            # Positive base with fractional exponent
            base_range = TosaTensorValuesGen._get_data_range(
                rng,
                dtype,
                TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
                TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
            )
            exp_range = TosaTensorValuesGen._get_data_range(
                rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
            )
            exp_round = False
        else:
            # Integer exponent
            exp_range = TosaTensorValuesGen._get_data_range(
                rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
            )
            exp_round = True
            if test_set == 1:
                # Positive base
                base_range = TosaTensorValuesGen._get_data_range(
                    rng,
                    dtype,
                    TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
                    TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
                )
            else:
                assert test_set == 2
                # Negative base
                # Supply new look up tables with negative values
                base_range = TosaTensorValuesGen._get_data_range(
                    rng,
                    dtype,
                    {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
                    {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
                )

        # Per-input ranges: first input is the base, second the exponent
        data_range_list = (
            {
                "range": base_range,
            },
            {
                "range": exp_range,
                "round": exp_round,
            },
        )
        argsDict["data_range_list"] = data_range_list
        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )
1639
1640 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001641 def tvgLogRsqrt(
1642 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1643 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001644 # LOG & RSQRT data range from lowest expressible positive number to
1645 # largest to avoid NaNs
1646 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001647 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001648 dtypeList[0],
1649 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1650 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1651 )
1652 if data_range:
1653 argsDict["data_range"] = data_range
1654
1655 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001656 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001657 )
1658
    # Set the EXP data range to the log of the largest to smallest values
    # to avoid infinities or making the result zero
    # (TVG_FLOAT_HIGH_VALUE/TVG_FLOAT_LOW_VALUE are class attributes defined
    # earlier in this class)
    TVG_FLOAT_HIGH_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
    TVG_FLOAT_LOW_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
    }
1671
1672 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001673 def tvgExp(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001674 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001675 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001676 dtypeList[0],
1677 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1678 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1679 )
1680 if data_range:
1681 argsDict["data_range"] = data_range
1682
1683 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001684 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001685 )
1686
1687 @staticmethod
1688 def tvgFullyConnected(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001689 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson30476252023-11-20 16:15:30 +00001690 ):
1691 dtype = dtypeList[0]
1692 if (
1693 error_name is None
1694 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
Jeremy Johnson718f3472023-11-30 14:18:19 +00001695 and dtype in (DType.BF16,)
Jeremy Johnson30476252023-11-20 16:15:30 +00001696 ):
Jeremy Johnson718f3472023-11-30 14:18:19 +00001697 # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
Jeremy Johnson30476252023-11-20 16:15:30 +00001698 # Limit ranges for (non error & non compliance) FP tests by using
1699 # values that can be multiplied on any axis to not hit infinity/NaN
1700 IC = shapeList[0][1]
1701 highval_lookup = {
1702 dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
1703 }
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001704 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnson30476252023-11-20 16:15:30 +00001705 assert data_range is not None
1706 argsDict["data_range"] = data_range
1707
1708 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001709 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001710 )
1711
Jeremy Johnson708da822023-11-15 16:25:45 +00001712 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001713 def tvgCast(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson708da822023-11-15 16:25:45 +00001714 in_dtype = dtypeList[0]
1715 out_dtype = argsDict["out_type"]
1716 # Create look up to limit input tensor to output type maximums to avoid
1717 # FP infinities and saturation of integers
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001718 out_range = rng.dTypeRange(out_dtype, high_inclusive=True)
Jeremy Johnson708da822023-11-15 16:25:45 +00001719 highval_lookup = {in_dtype: out_range[1]}
1720 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001721 rng,
Jeremy Johnson708da822023-11-15 16:25:45 +00001722 in_dtype,
1723 highval_lookup,
1724 )
1725
1726 assert data_range is not None
1727 argsDict["data_range"] = data_range
1728
1729 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001730 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson708da822023-11-15 16:25:45 +00001731 )
1732
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001733 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001734 def tvgGather(
1735 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1736 ):
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001737 K = shapeList[0][1]
1738
1739 # Fix the type of the indices tensor
1740 dtypeList[1] = DType.INT32
1741
1742 dtype = dtypeList[0]
1743 if not gtu.dtypeIsSupportedByCompliance(dtype):
1744 # Test unsupported by data generator
1745 op = testGen.TOSA_OP_LIST[opName]
1746 pCount, cCount = op["operands"]
1747 assert (
1748 pCount == 2 and cCount == 0
1749 ), "Op.GATHER must have 2 placeholders, 0 consts"
1750
1751 tens_ser_list = []
1752 for idx, shape in enumerate(shapeList):
1753 dtype = dtypeList[idx]
1754 if idx != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001755 arr = rng.randTensor(shape, dtype)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001756 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1757 else:
1758 # Limit data range of indices tensor upto K (exclusive)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001759 arr = rng.randTensor(shape, dtype, (0, K))
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001760 # To match old functionality - create indices as CONST
1761 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1762
1763 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1764
1765 else:
1766 # ERROR_IF or floating point test
1767 # Use inclusive values upto index K for indices tensor
1768 data_range_list = (
1769 {"range": None},
1770 {"range": (0, K - 1)},
1771 )
1772 argsDict["data_range_list"] = data_range_list
1773
1774 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001775 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001776 )
1777
    @staticmethod
    def tvgScatter(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Set up tensor values for SCATTER.

        Builds an INT32 indices tensor whose values never exceed the K
        dimension of the values_in tensor and never repeat an output index,
        as required by the specification."""
        K = shapeList[0][1]
        W = shapeList[2][1]

        # Work out an indices tensor here with data that doesn't exceed the
        # dimension K of the values_in tensor and does NOT repeat the same K
        # location as needed by the spec:
        # "It is not permitted to repeat the same output index within a single
        # SCATTER operation and so each output index occurs at most once."
        assert K >= W, "Op.SCATTER W must be smaller or equal to K"

        # Fix the type of the indices tensor
        dtypeList[1] = DType.INT32

        dtype = dtypeList[0]
        if not gtu.dtypeIsSupportedByCompliance(dtype):
            # Test unsupported by data generator
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 3 and cCount == 0
            ), "Op.SCATTER must have 3 placeholders, 0 consts"

            tens_ser_list = []
            for idx, shape in enumerate(shapeList):
                dtype = dtypeList[idx]
                if idx != 1:
                    arr = rng.randTensor(shape, dtype)
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    # Create the indices array
                    assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
                    arr = []
                    for n in range(shape[0]):
                        # Get a shuffled list of output indices (0 to K-1) and
                        # limit length to W
                        arr.append(rng.permutation(K)[:W])
                    indices_arr = np.array(arr, dtype=np.int32)  # (N, W)
                    # To match old functionality - create indices as CONST
                    tens_ser_list.append(
                        testGen.ser.addConst(shape, dtype, indices_arr)
                    )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        else:
            # ERROR_IF or floating point test
            # Use inclusive values upto index K for indices tensor
            data_range_list = (
                {"range": None},
                {"range": (0, K - 1)},
                {"range": None},
            )
            argsDict["data_range_list"] = data_range_list

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )
1839
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001840
1841class TosaArgGen:
1842 """Argument generators create exhaustive or random lists of attributes for
1843 operators that take attributes or other parameters.
1844
1845 The return value is a list of (descriptive_name, [arglist]) tuples where
1846 the descriptive_name is appended to the test name and the arglist is expanded
1847 as arguments to the operator build function.
1848 """
1849
    def __init__(self):
        # Stateless - all argument generators are static methods
        pass
1852
1853 @staticmethod
evacha019c96eef2024-02-07 11:21:55 +00001854 def _add_data_generators(testGen, opName, shapeList, dtype, arg_list, error_name):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001855 """Add extra tests for each type of data generator for this op."""
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001856 if (
1857 error_name is None
1858 and "data_gen" in testGen.TOSA_OP_LIST[opName]
1859 and gtu.dtypeIsSupportedByCompliance(dtype)
1860 ):
evacha01ad8e1e22024-03-19 12:42:17 +00001861 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"].get(
1862 dtype, (gtu.DataGenType.PSEUDO_RANDOM,)
1863 )
1864
Jeremy Johnson1271c442023-09-05 11:39:26 +01001865 else:
1866 # Error test or No data generator types listed - assume random
1867 dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
1868
1869 # Expand arg list with other data generator types
1870 new_arg_list = []
1871 for dg_type in dataGenTypesList:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001872 for arg_str, args_dict in arg_list:
evacha01ad8e1e22024-03-19 12:42:17 +00001873 gen_args_dict = args_dict.copy()
evacha014a205112024-03-08 16:39:24 +00001874 # Only create one test by default - no sets of tests
1875 num_test_sets = 0
1876
Jeremy Johnson1271c442023-09-05 11:39:26 +01001877 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
Jeremy Johnson30476252023-11-20 16:15:30 +00001878 if error_name is None:
evacha014a205112024-03-08 16:39:24 +00001879 num_test_sets = args_dict.get("num_test_sets", 0)
Jeremy Johnson1271c442023-09-05 11:39:26 +01001880
1881 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
1882 # Extra tests for each dot product test set
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001883 dot_products = args_dict["dot_products"]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001884 if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001885 shape_info = (
1886 " ({})".format(testGen.shapeStr(args_dict["shape"]))
1887 if "shape" in args_dict
1888 else ""
1889 )
Jeremy Johnsonaf090182024-02-13 18:25:39 +00001890 logger.info(
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00001891 f"Skipping {opName}{shape_info} {gtu.DTYPE_ATTRIBUTES[dtype]['json']} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
Jeremy Johnson1271c442023-09-05 11:39:26 +01001892 )
1893 continue
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001894 # KS and acc_type is required by all dot product generators
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001895 assert "ks" in args_dict
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001896 assert "acc_type" in args_dict
Jeremy Johnson1271c442023-09-05 11:39:26 +01001897
Jeremy Johnson30476252023-11-20 16:15:30 +00001898 num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS
1899
evacha01ad8e1e22024-03-19 12:42:17 +00001900 elif dg_type == gtu.DataGenType.FULL_RANGE:
1901 tensor_size = gtu.product(shapeList[0])
1902 if tensor_size < gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]:
1903 shape_info = " ({})".format(shapeList[0])
1904 logger.info(
1905 f"Skipping {opName}{shape_info} as tensor data size too small for full range of values {tensor_size} < {gtu.DTYPE_ATTRIBUTES[dtype]['fullset']}"
1906 )
1907 continue
evacha014a205112024-03-08 16:39:24 +00001908 # Large enough tensor data size for full range, add full test
evacha01ad8e1e22024-03-19 12:42:17 +00001909 arg_str = f"{arg_str}_full" if arg_str else "full"
1910 gen_args_dict["tags"] = args_dict.get("tags", []) + [
1911 "non_finite_fp_data"
1912 ]
1913
evacha014a205112024-03-08 16:39:24 +00001914 elif dg_type == gtu.DataGenType.FP_SPECIAL:
1915 shapes_set = {tuple(x) for x in shapeList}
1916 if len(shapes_set) != 1:
1917 logger.info(
1918 f"Changing {opName} input shapes {shapes_set} - broadcasting incompatable with special test"
1919 )
1920 shapeList = [np.int32(np.broadcast_shapes(*shapeList))] * len(
1921 shapeList
1922 )
1923 arg_str = f"{arg_str}_fs" if arg_str else "fs"
1924
evacha01ad8e1e22024-03-19 12:42:17 +00001925 gen_args_dict["dg_type"] = dg_type
Jeremy Johnson30476252023-11-20 16:15:30 +00001926 if num_test_sets > 0:
1927 for s in range(0, num_test_sets):
evacha019c96eef2024-02-07 11:21:55 +00001928 set_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
evacha01ad8e1e22024-03-19 12:42:17 +00001929 set_args_dict = gen_args_dict.copy()
evacha019c96eef2024-02-07 11:21:55 +00001930 set_args_dict["s"] = s
evacha019c96eef2024-02-07 11:21:55 +00001931 new_arg_list.append((set_arg_str, set_args_dict))
Jeremy Johnson30476252023-11-20 16:15:30 +00001932 else:
1933 # Default is a single test
evacha01ad8e1e22024-03-19 12:42:17 +00001934 new_arg_list.append((arg_str, gen_args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001935
1936 return new_arg_list
1937
1938 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001939 def agNone(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001940 """A trivial argument generator for operators that don't take any
1941 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001942 arg_list = TosaArgGen._add_data_generators(
1943 testGen,
1944 opName,
evacha019c96eef2024-02-07 11:21:55 +00001945 shapeList,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001946 dtype,
1947 [("", {})],
1948 error_name,
1949 )
1950 # Return list of tuples: (arg_str, args_dict)
1951 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001952
1953 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001954 def agPow(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001955 """Pow operator needs different test sets to cover random numbers
1956 without creating NaNs or Infs"""
1957 arg_list = TosaArgGen._add_data_generators(
1958 testGen,
1959 opName,
evacha019c96eef2024-02-07 11:21:55 +00001960 shapeList,
Jeremy Johnson30476252023-11-20 16:15:30 +00001961 dtype,
1962 [("", {"num_test_sets": 3})],
1963 error_name,
1964 )
1965 # Return list of tuples: (arg_str, args_dict)
1966 return arg_list
1967
1968 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001969 def agAxis(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001970 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001971 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001972 shape = shapeList[0]
1973
1974 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001975 # Set too small axis
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001976 axes = [rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001977 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001978 # Set too large axis
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001979 axes = [rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001980 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001981 # Create tests for each dimension
1982 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001983
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001984 opid = testGen.TOSA_OP_LIST[opName]["op"]
1985
1986 for a in axes:
1987 args_dict = {"axis": int(a)}
1988 if opid == Op.REDUCE_SUM:
Jeremy Johnsone52c0a32024-03-11 09:58:24 +00001989 output_shape = shape.copy()
1990 if error_name is None:
1991 # It only matters that we calculate the dot_products correctly
1992 # for non error_if tests as they should never be run
1993 output_shape[a] = 1
1994 args_dict["dot_products"] = gtu.product(output_shape)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001995 args_dict["shape"] = shape
1996 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1997 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1998
1999 arg_list.append(("axis{}".format(a), args_dict))
2000
2001 arg_list = TosaArgGen._add_data_generators(
2002 testGen,
2003 opName,
evacha019c96eef2024-02-07 11:21:55 +00002004 shapeList,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002005 dtype,
2006 arg_list,
2007 error_name,
2008 )
2009 # Return list of tuples: (arg_str, args_dict)
2010 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002011
2012 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002013 def _calculate_sparsity(num_tests, sparsity_factor):
2014 sparsity = num_tests // sparsity_factor + 1
2015 # If there are only a small number of tests, just select them all
2016 if sparsity < 13:
2017 sparsity = 1
2018 # To get a variety of parameter combinations sparsity should not be a
2019 # multiple of 2, 3 or 5
2020 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
2021 sparsity += 1
2022 return sparsity
2023
    # Maximum number of error_if variants to produce - argument generators
    # (e.g. agConv) stop adding negative tests once this many exist
    MAX_TESTS_ERROR_IFS = 3
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002026
    @staticmethod
    def agConv(testGen, rng, opName, shapeList, dtypes, error_name=None):
        """Build the argument list (accumulator type, stride, pad, dilation
        plus compliance info) for convolution operators.

        Used by CONV2D, CONV3D and DEPTHWISE_CONV2D.

        Returns a list of tuples: (arg_str, args_dict).
        """
        arg_list = []

        # Shape: Batches, (Depth), Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: (OFM channels), (KD), KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)

        if error_name == ErrorIf.WrongAccumulatorType:
            # Deliberately invalid accumulator type for the negative test
            accum_dtypes = (
                [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
            )

        # For op type checks
        op = testGen.TOSA_OP_LIST[opName]

        # Check the rank
        rank = 5 if op["op"] == Op.CONV3D else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits channels
        k_rank = rank - 2
        # DEPTHWISE_CONV2D filters are (KH, KW, C, M) so the kernel starts at 0
        k_pos = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
        k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
        # compliance size - KS
        k_size = gtu.product(k_shape)
        if not op["op"] == Op.DEPTHWISE_CONV2D:
            k_size *= ifm_shape[-1]

        def get_conv_output_info(p, s, d, fix_up_padding=False):
            # Work out remainders and output dimensions with an
            # option to adjust paddings to create a valid operation
            nonlocal ifm_shape, k_shape, error_name, k_rank
            if fix_up_padding:
                p = list(p)  # Make paddings editable
            outputs_no_stride = []
            remainders = []
            outputs = []
            for index in range(k_rank):
                pad_offset = index * 2
                fixed = False
                # Fix up pad values to produce valid conv2d
                while not fixed:
                    # Output dimension without being adjusted for stride
                    output_no_stride = (
                        ifm_shape[index + 1]
                        - 1
                        + p[pad_offset]
                        + p[pad_offset + 1]
                        - (k_shape[index] - 1) * d[index]
                    )
                    # Tensor left over after applying striding
                    remainder = output_no_stride % s[index]
                    if not fix_up_padding:
                        # Just want remainders and outputs
                        break
                    if output_no_stride <= 0:
                        # Output would be empty - grow the trailing pad and retry
                        p[pad_offset + 1] += abs(output_no_stride) + 1
                        continue
                    if error_name == ErrorIf.ConvOutputShapeNonInteger:
                        if remainder:
                            # Conditions to trigger the test
                            fixed = True
                        else:
                            p[pad_offset + 1] += 1
                    else:
                        if remainder:
                            # Stride will be negative for StrideSmallerOne
                            assert remainder > 0 or (
                                error_name == ErrorIf.StrideSmallerOne and remainder < 0
                            )
                            p[pad_offset + 1] += abs(remainder)
                        else:
                            fixed = True
                outputs_no_stride.append(output_no_stride)
                remainders.append(remainder)
                # Output dimension taking in to account stride
                outputs.append((output_no_stride // s[index]) + 1)

            if fix_up_padding:
                p = tuple(p)  # Make the paddings read-only
                assert min(outputs_no_stride) > 0, "Fix up did not work!"
            return p, remainders, outputs, outputs_no_stride

        # Only fix up padding for conv2d and float types currently
        fix_up_padding = gtu.dtypeIsFloat(dtypes[0]) and op["op"] == Op.CONV2D
        # Allow any size of output dimension
        max_dim_size = None
        # Include all tests by default
        sparsity = 1

        # Work out padding, strides and dilation ranges depending on
        # error and arguments
        if error_name in (
            ErrorIf.PadSmallerZero,
            ErrorIf.StrideSmallerOne,
            ErrorIf.DilationSmallerOne,
        ):
            # Use specific invalid value(s)
            if error_name == ErrorIf.PadSmallerZero:
                # Create negative paddings but with positive opposite paddings
                neg_pad = rng.choice(range(-5, 0))
                p_vals = [neg_pad, abs(neg_pad)]
            else:
                p_vals = [0, 0]
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [rng.choice(range(-5, 0))]
            else:
                s_vals = [1]
            if error_name == ErrorIf.DilationSmallerOne:
                d_vals = [rng.choice(range(-5, 1))]
            else:
                d_vals = [1]
            paddings = {tuple(p_vals) * k_rank}
            strides = {tuple(s_vals) * k_rank}
            dilations = {tuple(d_vals) * k_rank}

            fix_up_padding = True  # Need to fix up paddings to be valid

        elif testGen.args.level8k and error_name is None:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            dilation_shape = [1] * k_rank
            pad_shape = [0] * k_rank * 2
            if op["op"] == Op.CONV3D:
                # Small stride apart from for big kernel (see below) to keep
                # tensor size/calculation small
                stride_shape = [1] * k_rank
                for idx in range(k_rank):
                    pad_offset = idx * 2
                    if k_shape[idx] == bigKernel:
                        # Padding shape needs to account for tensor shape
                        pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                        pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                        # Big stride to reduce output size
                        stride_shape[idx] = bigKernel
                    else:
                        # Account for kernel size
                        pad_shape[pad_offset] = k_shape[idx] - 1
            else:
                # Always have a large stride with extra padding and dilation to keep
                # tensor calculation reasonable
                stride_shape = [bigKernel] * k_rank
                for idx in range(k_rank):
                    # Dilation shape must account for kernel size
                    dilation_shape[idx] = bigKernel // k_shape[idx]
                    # Padding shape needs to accommodate tensor/kernel & dilation
                    pad_offset = idx * 2
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

            strides = {tuple(stride_shape)}
            dilations = {tuple(dilation_shape)}
            paddings = {tuple(pad_shape)}
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

            # Currently allow all combinations that are reasonable size
            sparsity = 1
        else:
            # Generate comprehensive argument lists
            p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
            # Stride must be greater than 1 to force non-integer error
            startStride = 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
            s_vals = [x for x in range(startStride, testGen.args.max_conv_stride + 1)]
            d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]

            strides = {x for x in itertools.product(*([s_vals] * k_rank))}
            dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

            if error_name is None and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[0, bigPadding]] * (k_rank * 2))
                            )
                        }
                    )
                bigStride = 8
                strides.update(
                    {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
                )
                bigDilation = 7
                dilations.update(
                    {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
                )

            if error_name is None:
                # There are too many parameter combinations, so generate them sparsely,
                sparsity_factor = 120
                sparsity = TosaArgGen._calculate_sparsity(
                    len(paddings) * len(strides) * len(dilations), sparsity_factor
                )

        # Run through all the argument options creating valid test cases
        more_tests = True
        n = 0
        for a in accum_dtypes:
            for s in sorted(list(strides)):
                for p in sorted(list(paddings)):
                    for d in sorted(list(dilations)):
                        if more_tests and (n % sparsity == 0):
                            (
                                p,
                                remainders,
                                outputs,
                                outputs_no_stride,
                            ) = get_conv_output_info(p, s, d, fix_up_padding)
                            # Following is like checking each dimension N:
                            # (ifm_shape[N+1] - 1 + p[N*2] + p[N*2+1]) > d[N] * (k_shape[N] - 1)
                            if min(outputs_no_stride) <= 0:
                                # Not a valid operation
                                n += 1  # Increment count of tests
                                continue

                            if (
                                # the parameters must produce integer exact output
                                error_name != ErrorIf.ConvOutputShapeNonInteger
                                and max(remainders) == 0
                            ) or (
                                error_name == ErrorIf.ConvOutputShapeNonInteger
                                and max(remainders) > 0
                            ):
                                if (
                                    max_dim_size is not None
                                    and max(outputs) >= max_dim_size
                                ):
                                    # Test will consume too much memory - skip it
                                    logger.debug(
                                        "agConv: Convolution output too big - skipped"
                                    )
                                    continue

                                # Compliance - number of dot product calculations
                                if op["op"] == Op.DEPTHWISE_CONV2D:
                                    # N*OH*OW*C*M
                                    dots = gtu.product(
                                        (ifm_shape[0], *outputs, *filter_shape[2:])
                                    )
                                else:
                                    # N*OH*OW*OC or N*OD*OH*OW*OC
                                    dots = gtu.product(
                                        (ifm_shape[0], *outputs, filter_shape[0])
                                    )
                                args_dict = {
                                    "acc_type": a,
                                    "stride": s,
                                    "pad": p,
                                    "dilation": d,
                                    "kernel": k_shape,
                                    "ks": k_size,
                                    "dot_products": dots,
                                    "shape": ifm_shape,
                                }

                                # Support for larger values than 9 needs different delimiter
                                delim = "" if max(s + p + d) <= 9 else "x"
                                arg_list.append(
                                    (
                                        "acc{}_st{}_pad{}_dilat{}".format(
                                            testGen.typeStr(a),
                                            delim.join([str(x) for x in s]),
                                            delim.join([str(x) for x in p]),
                                            delim.join([str(x) for x in d]),
                                        ),
                                        args_dict,
                                    )
                                )
                                if (
                                    error_name
                                    and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS
                                ):
                                    # Found enough errors
                                    logger.debug(
                                        f"Skipping creating more conv error tests for {error_name}"
                                    )
                                    more_tests = False
                        n += 1

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            shapeList,
            dtypes[0],
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2331
2332 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002333 def agFullyConnected(testGen, rng, opName, shapeList, dtypes, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002334
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002335 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002336 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01002337
2338 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002339 accum_dtype = gtu.get_wrong_output_type(opName, rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01002340 elif error_name == ErrorIf.WrongInputType:
2341 # Pick some potentially correct output dtype if input type is incorrect
2342 accum_dtype = DType.INT32
2343 else:
Tai Lyf36f2562024-03-14 16:21:29 +00002344 accum_dtype = dtypes[-1] # use output dtype as accum_dtype
James Ward8b390432022-08-12 20:48:56 +01002345
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002346 # Set up compliance info
2347 args_dict = {
2348 "acc_type": accum_dtype,
2349 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
2350 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
2351 "shape": shapeList[0],
2352 }
2353
2354 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
2355
2356 arg_list = TosaArgGen._add_data_generators(
2357 testGen,
2358 opName,
evacha019c96eef2024-02-07 11:21:55 +00002359 shapeList,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002360 input_dtype,
2361 arg_list,
2362 error_name,
2363 )
2364 # Return list of tuples: (arg_str, args_dict)
2365 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002366
2367 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002368 def agMatMul(testGen, rng, opName, shapeList, dtype, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002369 # Get valid accumulate type(s)
2370 if dtype == DType.INT8:
2371 accum_dtypes = [DType.INT32]
2372 elif dtype == DType.INT16:
2373 accum_dtypes = [DType.INT48]
2374 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002375 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002376 elif dtype == DType.BF16:
2377 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002378 elif dtype == DType.FP32:
2379 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002380 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2381 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002382 elif error_name is None:
2383 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
2384
2385 if error_name == ErrorIf.WrongOutputType:
2386 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002387 accum_dtypes = [gtu.get_wrong_output_type(opName, rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01002388 elif error_name == ErrorIf.WrongInputType:
2389 # Pick some potentially correct output dtype if input type is incorrect
2390 accum_dtypes = [DType.INT32]
2391
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002392 # Set up compliance info
2393 args_dict = {
2394 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
2395 # Set dot_products = N*H*W
2396 "dot_products": gtu.product(
2397 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
2398 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002399 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002400 }
2401
2402 # Create arg tuple of string and dict
2403 arg_list = []
2404 for a in accum_dtypes:
2405 d = args_dict.copy()
2406 d["acc_type"] = a
2407 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01002408
2409 arg_list = TosaArgGen._add_data_generators(
2410 testGen,
2411 opName,
evacha019c96eef2024-02-07 11:21:55 +00002412 shapeList,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002413 dtype,
2414 arg_list,
2415 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002416 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002417 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01002418 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002419
    @staticmethod
    def agTransposeConv2D(testGen, rng, opName, shapeList, dtypes, error_name=None):
        """Build the argument list (accumulator type, stride, pad, output shape
        plus compliance info) for TRANSPOSE_CONV2D.

        Returns a list of tuples: (arg_str, args_dict).
        """
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)

        if error_name == ErrorIf.WrongAccumulatorType:
            # Deliberately invalid accumulator type for the negative test
            accum_dtypes = (
                [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
            )

        # Must be rank 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == 4
            assert len(filter_shape) == 4

        # Kernel is (KH, KW) from the (OC, KH, KW, IC) filter shape
        k_shape = tuple(filter_shape[1:3])
        # compliance size - KS
        k_size = gtu.product((*k_shape, ifm_shape[3]))

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
            if error_name == ErrorIf.PadLargerEqualKernel:
                max_filter_size = -max(k_shape[0], k_shape[1])
                p_vals = [rng.choice(range(max_filter_size - 10, max_filter_size))]
            else:
                p_vals = [
                    x
                    for x in range(
                        smallest_padding_size, testGen.args.max_conv_padding + 1
                    )
                ]
            paddings = {x for x in itertools.product(*([p_vals] * 4))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [rng.choice(range(-5, 0))]
            else:
                s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
            strides = {x for x in itertools.product(*([s_vals] * 2))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[smallest_padding_size, bigPadding]] * 4)
                            )
                        }
                    )
                bigStride = 8
                strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 10
            sparsity = len(paddings) * len(strides) // sparsity_factor + 1
            # If there are only a small number of tests, just select them all
            if sparsity < 13:
                sparsity = 1
            # To get a variety of parameter combinations sparsity should not be a
            # multiple of 2, 3 or 5
            while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
                sparsity += 1
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            pad_shape = [0] * (len(k_shape) * 2)
            stride_shape = [1] * len(k_shape)
            # The point at which input dimension combined with the stride will
            # create large output sizes!
            LARGE_SIZE = 2
            for idx in range(len(k_shape)):
                pad_offset = idx * 2
                if k_shape[idx] == bigKernel:
                    # Set large stride
                    stride_shape[idx] = bigKernel
                    # Use negative output padding to reduce shape size
                    pad_shape[pad_offset] = -(bigPadding - 1)
                    if ifm_shape[idx + 1] > LARGE_SIZE:
                        pad_shape[pad_offset + 1] = -(bigPadding - 1)
                else:
                    # The other dimension should be the bigKernel
                    alt_idx = 1 - idx
                    if (
                        k_shape[alt_idx] == bigKernel
                        and ifm_shape[alt_idx + 1] < LARGE_SIZE
                    ):
                        # As the input is small, the large stride won't
                        # affect the output so we can add some padding
                        pad_shape[pad_offset + 1] = bigPadding

            strides = {tuple(stride_shape)}
            paddings = {tuple(pad_shape)}

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        n = 0
        for a in accum_dtypes:
            for s in sorted(list(strides)):
                for p in sorted(list(paddings)):
                    if n % sparsity == 0:
                        # Determine the output shape
                        oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
                        ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
                        os = [ifm_shape[0], oh, ow, filter_shape[0]]

                        # Compliance - number of dot product calculations: N*OH*OW*OC
                        dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
                        args_dict = {
                            "acc_type": a,
                            "stride": s,
                            "pad": p,
                            "kernel": k_shape,
                            "ks": k_size,
                            "dot_products": dots,
                            "shape": ifm_shape,
                            "out_shape": os,
                        }

                        # Support for larger values than 9 needs different delimiter
                        delim = "" if max(s + p) <= 9 else "x"
                        arg_list.append(
                            (
                                "acc{}_st{}_pad{}_os{}".format(
                                    testGen.typeStr(a),
                                    delim.join([str(x) for x in s]),
                                    delim.join([str(x) for x in p]),
                                    "x".join([str(x) for x in os]),
                                ),
                                args_dict,
                            )
                        )
                    n += 1

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            shapeList,
            dtypes[0],
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2580
    @staticmethod
    def agPad(testGen, rng, opName, shapeList, dtype, error_name=None):
        """Build the argument list (padding amounts and pad constants) for PAD.

        Returns a list of tuples: (arg_str, args_dict).
        """
        rank = len(shapeList[0])

        # Choose the per-side padding values to combine
        if error_name is None and testGen.args.oversize:
            pad_values = [6, 7, 10, 13]
        elif error_name == ErrorIf.PadSmallerZero:
            # Invalid negative paddings for the negative test
            pad_values = [x for x in range(-2, 0)]
        else:
            # Exhaustively test combinations of padding on each side of each dimension
            # - the range of padding values is defined by pad_min and pad_max
            pad_min, pad_max = 0, 1
            pad_values = [x for x in range(pad_min, pad_max + 1)]

        # Calculate pad combinations: (before, after) pairs per axis,
        # then the product of those pairs across all axes
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        # Choose the pad constant matching the data type
        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = rng.randNumberDType(dtype)
            pad_const_fp = 0
        elif gtu.dtypeIsFloat(dtype):
            pad_const_int = 0
            pad_const_fp = rng.randNumberDType(dtype)
        else:
            # Only the WrongInputType negative test should reach here
            assert error_name == ErrorIf.WrongInputType
            pad_const_int = 0
            pad_const_fp = 0

        list_shape_pad_values = list(shape_pad_values)
        # If we are producing tests for rank 6 or greater use sparsity
        if len(list_shape_pad_values) > 1024:
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(list_shape_pad_values), sparsity_factor
            )
        else:
            sparsity = 1

        # Build arg list
        arg_list = []
        for n, paddings in enumerate(list_shape_pad_values):
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                    if all([p > -1 for p in paddings[i]]):
                        # No negative padding left on this axis - drop the case
                        args_valid = False
            if args_valid and n % sparsity == 0:
                # Work out name
                pad_list = []
                for r in range(rank):
                    pad_list.extend(paddings[r])

                delim = "" if max(pad_list) <= 9 else "x"
                name = "pad{}".format(delim.join([str(x) for x in pad_list]))

                args_dict = {
                    "pad": np.array(paddings),
                    "pad_const_int": pad_const_int,
                    "pad_const_fp": pad_const_fp,
                }
                arg_list.append((name, args_dict))

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            logger.debug(
                f"agPad: No PadSmallerZero ErrorIf test created for input shape: {shapeList[0]}"
            )

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            shapeList,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2668
2669 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002670 def agPooling(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002671 arg_list = []
2672
2673 shape = shapeList[0]
2674 if error_name != ErrorIf.WrongRank:
2675 assert len(shape) == 4
2676
Jeremy Johnson0c716862023-04-13 17:18:19 +01002677 test_level8k = testGen.args.level8k and error_name is None
2678
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002679 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002680 startKernel = 2
2681 startPad = 0
2682 if not test_level8k:
2683 # Generate comprehensive argument lists
2684 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2685 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2686 # Stride must be greater than 1 to force non-integer error
2687 s_vals = [
2688 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2689 ]
2690 strides = {x for x in itertools.product(*([s_vals] * 2))}
2691 k_vals = [
2692 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2693 ]
2694 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2695 max_dim_size = None
2696 else:
2697 # Only test 8k levels
2698 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2699 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2700 strides = {(1, bigStride), (bigStride, 4)}
2701 kernels = {(1, bigKernel), (bigKernel, 3)}
2702 paddings = set()
2703 for s in sorted(list(strides)):
2704 for k in sorted(list(kernels)):
2705 padding = []
2706 for idx in range(len(k)):
2707 total_padding = s[idx] - shape[idx + 1] + k[idx]
2708 while total_padding < 0:
2709 # Must meet: shape + padding > kernel
2710 total_padding += s[idx]
2711 if total_padding < k[idx]:
2712 padding.extend([0, total_padding])
2713 else:
2714 # Note this may produce padding >= k[idx] which is not
2715 # allowed - but will be ignored in the creation loop below
2716 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2717 paddings.add(tuple(padding))
2718 # Create a limit for the output dimensions size
2719 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002720
James Ward8b390432022-08-12 20:48:56 +01002721 if opName == "max_pool2d":
2722 accum_dtypes = [None] # max_pool has no accumulate dtype
2723 elif dtype == DType.INT8 or dtype == DType.INT16:
2724 accum_dtypes = [DType.INT32]
2725 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002726 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002727 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002728 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002729 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2730 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002731 elif error_name is None:
2732 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2733 else:
2734 # Set to something for the ErrorIf case which has
2735 # incorrect input data-type
2736 accum_dtypes = [DType.INT32]
2737
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00002738 if error_name == ErrorIf.WrongAccumulatorType:
2739 accum_dtypes = list(gtu.usableDTypes(excludes=accum_dtypes))
2740
Jeremy Johnson0c716862023-04-13 17:18:19 +01002741 if not test_level8k:
2742 if testGen.args.oversize:
2743 # add some oversize argument values
2744 bigStride = 7
2745 bigKernel = 9
2746 strides.update(
2747 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002748 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002749 kernels.update(
2750 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2751 )
2752 if max(shape) < 64:
2753 # padding must be less than the kernel size
2754 bigPadding = bigKernel - 1
2755 paddings.update(
2756 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2757 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002758
Jeremy Johnson87460262024-03-25 09:46:02 +00002759 if error_name:
2760 # Cycle through all error_if tests but we only keep the first few
2761 sparsity = 1
2762 else:
2763 # There are too many parameter combinations, so generate them sparsely
2764 sparsity_factor = 500
2765 sparsity = (
2766 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2767 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002768 else:
2769 # We have already limited test output combinations for 8k tests
2770 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002771
James Ward8b390432022-08-12 20:48:56 +01002772 arg_str = (
2773 "acc{}_st{}_kern{}_pad{}"
2774 if accum_dtypes[0] is not None
2775 else "st{}_kern{}_pad{}"
2776 )
2777
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002778 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01002779 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002780 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01002781
2782 # Support for larger values than 9 needs different delimiter
2783 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01002784 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01002785 delim.join([str(x) for x in stride]),
2786 delim.join([str(x) for x in kern]),
2787 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01002788 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002789 args_dict = {
2790 "stride": stride,
2791 "pad": pad,
2792 "kernel": kern,
2793 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002794 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002795 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
2796 }
James Ward8b390432022-08-12 20:48:56 +01002797
2798 if accum is not None:
2799 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002800 args_dict["acc_type"] = accum
2801 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002802
Jeremy Johnson87460262024-03-25 09:46:02 +00002803 more_tests = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002804 n = 0
James Ward8b390432022-08-12 20:48:56 +01002805 for a in accum_dtypes:
2806 for s in sorted(list(strides)):
2807 for p in sorted(list(paddings)):
2808 for k in sorted(list(kernels)):
2809 if error_name in [
2810 ErrorIf.StrideSmallerOne,
2811 ErrorIf.KernelSmallerOne,
2812 ErrorIf.PadSmallerZero,
2813 ErrorIf.PadLargerEqualKernel,
2814 ]:
2815 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002816 rng, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002817 )
James Ward8b390432022-08-12 20:48:56 +01002818 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002819 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002820 get_arg_list_element(a, sNew, pNew, kNew, shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002821 )
James Ward8b390432022-08-12 20:48:56 +01002822 elif (
Jeremy Johnson87460262024-03-25 09:46:02 +00002823 more_tests
2824 and n % sparsity == 0
James Ward8b390432022-08-12 20:48:56 +01002825 # padding must not exceed the kernel size
2826 and p[0] < k[0]
2827 and p[1] < k[0]
2828 and p[2] < k[1]
2829 and p[3] < k[1]
2830 # the padded shape must exceed the kernel size
2831 and (shape[1] + p[0] + p[1]) > k[0]
2832 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002833 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002834 partial_h = shape[1] + p[0] + p[1] - k[0]
2835 partial_w = shape[2] + p[2] + p[3] - k[1]
2836 remainder_h = partial_h % s[0]
2837 remainder_w = partial_w % s[1]
2838 output_h = partial_h // s[0] + 1
2839 output_w = partial_w // s[1] + 1
Jeremy Johnsonaf090182024-02-13 18:25:39 +00002840 logger.debug(
2841 f"agPooling: {shape} remainder=({remainder_h}, {remainder_w}) output=({output_h}, {output_w})"
2842 )
James Ward8b390432022-08-12 20:48:56 +01002843 if (
2844 # the parameters must produce integer exact output
2845 error_name != ErrorIf.PoolingOutputShapeNonInteger
2846 and remainder_h == 0
2847 and remainder_w == 0
2848 ) or (
2849 error_name == ErrorIf.PoolingOutputShapeNonInteger
2850 and (remainder_h != 0 or remainder_w != 0)
2851 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002852 if (
2853 max_dim_size is not None
2854 and max(output_h, output_w) > max_dim_size
2855 ):
2856 # Test will consume too much memory - skip it
2857 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002858 # Dot products = N*OH*OW*C
2859 dp = gtu.product(
2860 (shape[0], output_h, output_w, shape[3])
2861 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002862 arg_list.append(
2863 get_arg_list_element(a, s, p, k, dp, shape)
2864 )
Jeremy Johnson87460262024-03-25 09:46:02 +00002865 if (
2866 error_name
2867 and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS
2868 ):
2869 # Found enough errors
2870 logger.debug(
2871 f"Skipping creating more pooling error tests for {error_name}"
2872 )
2873 more_tests = False
2874
James Ward8b390432022-08-12 20:48:56 +01002875 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002876
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002877 # Now add data generator types
2878 arg_list = TosaArgGen._add_data_generators(
2879 testGen,
2880 opName,
evacha019c96eef2024-02-07 11:21:55 +00002881 shapeList,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002882 dtype,
2883 arg_list,
2884 error_name,
2885 )
2886
2887 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002888 return arg_list
2889
2890 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002891 def agCast(testGen, rng, opName, shapeList, inDtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002892 arg_list = []
2893
2894 # Enumerate the output types here
2895 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002896 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(inDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002897 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00002898 dtypeList = [
2899 DType.BOOL,
2900 DType.INT16,
2901 DType.INT32,
2902 DType.FP16,
2903 DType.BF16,
2904 DType.FP32,
2905 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002906 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00002907 dtypeList = [
2908 DType.BOOL,
2909 DType.INT8,
2910 DType.INT32,
2911 DType.FP16,
2912 DType.BF16,
2913 DType.FP32,
2914 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002915 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00002916 dtypeList = [
2917 DType.BOOL,
2918 DType.INT8,
2919 DType.INT16,
2920 DType.FP16,
2921 DType.BF16,
2922 DType.FP32,
2923 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002924 elif inDtype == DType.BOOL:
2925 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01002926 elif inDtype == DType.FP16:
Won Jeon2c34b462024-02-06 18:37:00 +00002927 dtypeList = [
2928 DType.INT8,
2929 DType.INT16,
2930 DType.INT32,
2931 DType.FP32,
2932 DType.FP8E4M3,
2933 DType.FP8E5M2,
2934 ]
James Ward24dbc422022-10-19 12:20:31 +01002935 elif inDtype == DType.BF16:
Won Jeon2c34b462024-02-06 18:37:00 +00002936 dtypeList = [
2937 DType.INT8,
2938 DType.INT16,
2939 DType.INT32,
2940 DType.FP32,
2941 DType.FP8E4M3,
2942 DType.FP8E5M2,
2943 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002944 elif inDtype == DType.FP32:
Won Jeon2c34b462024-02-06 18:37:00 +00002945 dtypeList = [
2946 DType.INT8,
2947 DType.INT16,
2948 DType.INT32,
2949 DType.FP16,
2950 DType.BF16,
2951 DType.FP8E4M3,
2952 DType.FP8E5M2,
2953 ]
2954 elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]:
2955 dtypeList = [DType.FP16, DType.BF16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002956 elif error_name == ErrorIf.WrongInputType:
2957 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002958 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002959 else:
2960 raise Exception("Unexpected input dtype: {}".format(inDtype))
2961
2962 for dtype in dtypeList:
Jeremy Johnson708da822023-11-15 16:25:45 +00002963 arg_list.append(
2964 ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
2965 )
2966
2967 # Now add data generator types
2968 arg_list = TosaArgGen._add_data_generators(
2969 testGen,
2970 opName,
evacha019c96eef2024-02-07 11:21:55 +00002971 shapeList,
Jeremy Johnson708da822023-11-15 16:25:45 +00002972 dtype,
2973 arg_list,
2974 error_name,
2975 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002976
2977 return arg_list
2978
    @staticmethod
    def agRescale(testGen, rng, opName, shapeList, inDtype, error_name=None):
        """Generate RESCALE argument combinations.

        Enumerates output dtype x scale32 x double_round x per_channel,
        skipping combinations that are illegal for the given input dtype
        (or deliberately keeping illegal ones when an ERROR_IF test needs
        them), then computes random multiplier/shift quantization
        parameters for each kept combination.

        Returns a list of (arg_str, args_dict) tuples.
        """
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                # These output types allow a non-zero zero point, so the
                # OutputZeroPointNotZero error cannot be produced with them
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                # Only keep in/out pairs that are suitable for this ERROR_IF
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    # Per_channel is only valid with rank > 0
                    pc_options = (False, True) if len(shapeList[0]) > 0 else (False,)
                    for per_channel in pc_options:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition - INT48 must use scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        # Number of channels: one multiplier/shift pair per
                        # channel when per_channel, otherwise a single pair
                        if per_channel:
                            nc = shapeList[0][-1]
                        else:
                            nc = 1

                        in_type_width = gtu.dtypeWidth(inDtype)
                        out_type_width = gtu.dtypeWidth(outDtype)

                        # Calculate a random scale per channel based on:
                        # scale = a * (2^output_width) / (2^input_width)
                        a = np.float32(rng.random(size=[nc]))
                        scale_arr = a * np.float32(
                            (1 << out_type_width) / (1 << in_type_width)
                        )

                        if scale32:
                            # Cap the scaling at 2^31 - 1 for scale32
                            scale_arr = np.clip(
                                scale_arr, 1.0 / (1 << 31), (1 << 31) - 1
                            )
                        else:
                            # Cap the scaling at 2^15 - 1 for scale16
                            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

                        logger.debug(
                            f"agRescale: {out_type_width} {in_type_width} -> {scale_arr}"
                        )

                        # Convert each float scale into the RESCALE
                        # multiplier/shift fixed-point representation
                        multiplier_arr = np.int32(np.zeros(shape=[nc]))
                        shift_arr = np.int32(np.zeros(shape=[nc]))
                        for i in range(nc):
                            (
                                multiplier_arr[i],
                                shift_arr[i],
                            ) = TosaQuantGen.computeMultiplierAndShift(
                                scale_arr[i], scale32
                            )

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                {
                                    "output_dtype": outDtype,
                                    "scale": scale32,
                                    "double_round": double_round,
                                    "per_channel": per_channel,
                                    "multiplier": multiplier_arr,
                                    "shift": shift_arr,
                                },
                            )
                        )

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            shapeList,
            inDtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
3134
3135 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003136 def agMul(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003137 arg_list = []
3138
3139 if dtype is DType.INT32:
3140 for p in range(testGen.args.num_rand_permutations):
3141
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003142 shift = rng.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003143 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003144 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003145 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003146
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003147 arg_list = TosaArgGen._add_data_generators(
3148 testGen,
3149 opName,
evacha019c96eef2024-02-07 11:21:55 +00003150 shapeList,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003151 dtype,
3152 arg_list,
3153 error_name,
3154 )
3155 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003156 return arg_list
3157
3158 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003159 def agArithmeticRightShift(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003160 arg_list = []
3161
Jeremy Johnson587cc842024-02-08 11:45:44 +00003162 for round in (True, False):
3163 args_dict = {
3164 "round": round,
3165 }
3166 arg_list.append((f"round{round}", args_dict))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003167
Jeremy Johnson587cc842024-02-08 11:45:44 +00003168 arg_list = TosaArgGen._add_data_generators(
3169 testGen,
3170 opName,
evacha019c96eef2024-02-07 11:21:55 +00003171 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003172 dtype,
3173 arg_list,
3174 error_name,
3175 )
3176 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003177 return arg_list
3178
Luke Hutton57287132023-02-06 14:54:18 +00003179 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003180 def agFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Luke Hutton57287132023-02-06 14:54:18 +00003181 arg_list = []
3182
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003183 shape = shapeList[0]
3184 dot_products = gtu.product(shape)
3185 ks = 2 * shape[1] * shape[2] # 2*H*W
3186 for inverse in (True, False):
3187 args_dict = {
3188 "dot_products": dot_products,
3189 "shape": shape,
3190 "ks": ks,
3191 "acc_type": dtype,
3192 "inverse": inverse,
3193 }
3194 arg_list.append((f"inverse{inverse}", args_dict))
Luke Hutton57287132023-02-06 14:54:18 +00003195
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003196 arg_list = TosaArgGen._add_data_generators(
3197 testGen,
3198 opName,
evacha019c96eef2024-02-07 11:21:55 +00003199 shapeList,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003200 dtype,
3201 arg_list,
3202 error_name,
3203 )
3204 # Return list of tuples: (arg_str, args_dict)
Luke Hutton57287132023-02-06 14:54:18 +00003205 return arg_list
3206
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003207 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003208 def agRFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003209 arg_list = []
3210
3211 shape = shapeList[0]
3212 dot_products = gtu.product(shape)
3213 ks = shape[1] * shape[2] # H*W
3214 args_dict = {
3215 "dot_products": dot_products,
3216 "shape": shape,
3217 "ks": ks,
3218 "acc_type": dtype,
3219 }
3220 arg_list.append(("", args_dict))
3221
3222 arg_list = TosaArgGen._add_data_generators(
3223 testGen,
3224 opName,
evacha019c96eef2024-02-07 11:21:55 +00003225 shapeList,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003226 dtype,
3227 arg_list,
3228 error_name,
3229 )
3230 # Return list of tuples: (arg_str, args_dict)
3231 return arg_list
3232
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003233 # Helper function for reshape. Gets some factors of a larger number.
3234 @staticmethod
3235 def getFactors(val, start=1):
3236 factors = []
3237
3238 for i in range(start, int(np.sqrt(val)) + 1):
3239 if (val % i) == 0:
3240 factors.append(i)
3241
3242 return factors
3243
3244 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003245 def agReshape(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003246 arg_list = []
3247
3248 origShape = shapeList[0]
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003249 totalElements = gtu.product(origShape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003250 factors = TosaArgGen.getFactors(totalElements)
3251
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003252 # Find new shapes up to the number of permutations asked for
3253 # This code is NOT fast. Fortunately, the numbers are fairly small.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003254 for p in range(testGen.args.num_rand_permutations):
Jeremy Johnsondd975b82024-02-28 17:29:13 +00003255 # Rank from 1 to MAX_TENSOR_RANK
3256 newRank = rng.randInt(1, (gtu.MAX_TENSOR_RANK + 1))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003257 if len(factors) < newRank:
3258 continue
3259
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003260 # escape_counter limits the generation of new shapes to a reasonable time
3261 for escape_counter in range(100):
3262
3263 # Generate the new shape of the chosen new rank
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003264 newShape = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003265 remainingElements = totalElements
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003266 shuffledFactors = rng.permutation(factors)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003267 for i in range(1, newRank):
3268 # pick rank-1 factors
3269 newShape.append(shuffledFactors[0])
3270 remainingElements = remainingElements // shuffledFactors[0]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003271 shuffledFactors = rng.permutation(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003272 TosaArgGen.getFactors(remainingElements)
3273 )
3274 newShape.append(remainingElements)
3275
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003276 # Check for duplicates
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003277 duplicate = False
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003278 for name, args_dict in arg_list:
3279 if args_dict["new_shape"] == newShape:
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003280 duplicate = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003281 break
3282
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003283 if not duplicate:
3284 outShape = "x".join([str(x) for x in newShape])
3285 arg_list.append(
3286 (
3287 "perm{}_rank{}_out{}".format(p, newRank, outShape),
3288 {"new_shape": newShape},
3289 )
3290 )
3291 # Found an output shape for this permutation
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003292 break
3293
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003294 # Now add data generator types
3295 arg_list = TosaArgGen._add_data_generators(
3296 testGen,
3297 opName,
evacha019c96eef2024-02-07 11:21:55 +00003298 shapeList,
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003299 dtype,
3300 arg_list,
3301 error_name,
3302 )
3303
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003304 return arg_list
3305
3306 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003307 def agTranspose(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003308 arg_list = []
3309
3310 ifm_shape = shapeList[0]
3311
3312 if error_name == ErrorIf.IndexOutsideBounds:
3313 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
3314 incorrect_small_index = range(-len(ifm_shape), 0)
3315 permutations = [p for p in itertools.permutations(incorrect_large_index)]
3316 permutations.extend(
3317 [p for p in itertools.permutations(incorrect_small_index)]
3318 )
3319 elif error_name == ErrorIf.IndexUsedTwice:
3320 # Create list with a duplicated index
3321 perm_range = list(range(len(ifm_shape)))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003322 index_choice = rng.choice(range(len(perm_range)))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003323 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
3324 permutations = [p for p in itertools.permutations(perm_range)]
3325
3326 else:
3327 # Get all permutations
3328 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
3329
3330 # Limit to possible permutations from shape dimension or argument setting
3331 limit = min(len(permutations), testGen.args.num_rand_permutations)
3332
3333 # Get random permutation generator that uses all permutations
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003334 random_permutations = rng.permutation(permutations)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003335
3336 # Create list of required amount of permutations
3337 arg_list = [
evacha0198477222024-01-26 12:25:32 +00003338 ("perm{}".format(p), {"perms": random_permutations[p].tolist()})
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003339 for p in range(limit)
3340 ]
evacha0198477222024-01-26 12:25:32 +00003341 # Now add data generator types
3342 arg_list = TosaArgGen._add_data_generators(
3343 testGen,
3344 opName,
evacha019c96eef2024-02-07 11:21:55 +00003345 shapeList,
evacha0198477222024-01-26 12:25:32 +00003346 dtype,
3347 arg_list,
3348 error_name,
3349 )
3350 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003351 return arg_list
3352
3353 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003354 def agSlice(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003355 arg_list = []
3356
3357 ifm_shape = shapeList[0]
3358 rank = len(ifm_shape)
3359
3360 for p in range(testGen.args.num_rand_permutations):
3361 start = []
3362 size = []
3363
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003364 for i in range(rank):
3365 if ifm_shape[i] > 1:
Jeremy Johnson3f3de012024-04-08 15:18:05 +01003366 # Start from 0 to dimension size - 1 to leave room for slice of 1
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003367 start.append(rng.randInt(0, ifm_shape[i]))
Jeremy Johnson3f3de012024-04-08 15:18:05 +01003368 # Size from 1 up to rest of room (dimension size - start)
3369 size.append(rng.randInt(1, ifm_shape[i] + 1 - start[i]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003370
Jeremy Johnson3f3de012024-04-08 15:18:05 +01003371 # Should never hit an invalid slice size
3372 assert size[i] > 0 and (size[i] + start[i]) <= ifm_shape[i]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003373 else:
3374 start.append(0)
3375 size.append(1)
3376
Jeremy Johnson3f3de012024-04-08 15:18:05 +01003377 # If ERROR_IF test required then incorrect start, size will be returned
3378 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
3379 rng, error_name, ifm_shape, start, size
3380 )
3381 arg_list.append(("perm{}".format(p), {"start": start, "size": size}))
3382
evacha017f7d4252024-01-24 12:08:09 +00003383 # Now add data generator types
3384 arg_list = TosaArgGen._add_data_generators(
3385 testGen,
3386 opName,
evacha019c96eef2024-02-07 11:21:55 +00003387 shapeList,
evacha017f7d4252024-01-24 12:08:09 +00003388 dtype,
3389 arg_list,
3390 error_name,
3391 )
3392 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003393 return arg_list
3394
3395 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003396 def agTile(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003397 arg_list = []
3398
3399 ifm_shape = shapeList[0]
3400 rank = len(ifm_shape)
3401
3402 for p in range(testGen.args.num_rand_permutations):
3403
3404 # Pick a few random, but small multiple values
3405 # because otherwise this has a tendency to generate
3406 # enormous tensors
3407 multiples = []
3408 for i in range(rank):
3409 if ifm_shape[i] > 1000:
3410 # Multiple of 1 if ifm_shape dimension is large to reduce
3411 # tensor size
3412 multiples.append(1)
3413 elif max(ifm_shape) > 1000:
3414 multiples.append(2)
3415 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003416 multiples.append(rng.randInt(1, 4))
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003417 arg_list.append(("perm{}".format(p), {"multiples": multiples}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003418
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003419 # Now add data generator types
3420 arg_list = TosaArgGen._add_data_generators(
3421 testGen,
3422 opName,
evacha019c96eef2024-02-07 11:21:55 +00003423 shapeList,
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003424 dtype,
3425 arg_list,
3426 error_name,
3427 )
3428 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003429 return arg_list
3430
3431 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003432 def agResize(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003433 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003434 ifm_shape = shapeList[0]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003435
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003436 def get_aspect_ratio_resize_params():
3437 common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003438 aspect_ratio = rng.choice(common_aspect_ratios)
3439 invert = rng.choice((False, True))
3440 letterbox = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003441
3442 scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
3443 scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
3444 scale_y_d = scale_x_d = 1
3445 offset_x = offset_y = 0
3446
3447 if letterbox:
3448 max_border = scale_y_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003449 border_y = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003450 border_x = 0
3451 else:
3452 # Pillarboxing
3453 border_y = 0
3454 max_border = scale_x_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003455 border_x = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003456
3457 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3458 offset = (offset_y, offset_x)
3459 border = (border_y, border_x)
3460
3461 return scale, offset, border
3462
3463 def get_upscale_downscale_params():
3464 valid_params = False
3465 while not valid_params:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003466 upscale = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003467
3468 # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003469 origin_sampling = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003470
3471 if upscale:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003472 shift = rng.randInt(low=1, high=4)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003473 scale_x_d = scale_y_d = 1
3474 scale_x_n = scale_y_n = (
3475 1 << shift if origin_sampling else 2 << shift
3476 )
3477 border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
3478 offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
3479 else:
3480 scale_x_n = 1
3481 scale_y_n = 1
3482
3483 # Return list of valid scale_*_d values (max value 4) given input dim shape
3484 def get_valid_denom(ifm_dim):
3485 return [x for x in range(1, 5) if ifm_dim % x == 1]
3486
3487 # Generate list of valid downscale values and choose one randomly
3488 valid_scale_y_ds = get_valid_denom(ifm_shape[1])
3489 valid_scale_x_ds = get_valid_denom(ifm_shape[2])
3490
3491 if not valid_scale_y_ds and not valid_scale_x_ds:
3492 # Bad parameters, skip
3493 continue
3494
3495 if not valid_scale_y_ds:
3496 scale_y_d = 1
3497 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003498 scale_y_d = rng.choice(valid_scale_y_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003499
3500 if not valid_scale_x_ds:
3501 scale_x_d = 1
3502 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003503 scale_x_d = rng.choice(valid_scale_x_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003504
3505 border_x = border_y = 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003506 offset_y = rng.randInt(0, 16 * scale_y_n)
3507 offset_x = rng.randInt(0, 16 * scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003508 valid_params = True
3509
3510 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3511 offset = (offset_y, offset_x)
3512 border = (border_y, border_x)
3513 return scale, offset, border
3514
3515 def get_rand_params():
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003516 def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
3517 scale = scale_n / scale_d
3518 if scale > max_scale:
3519 factor = scale / max_scale
3520 new_scale_d = math.ceil(scale_d * factor)
3521 assert scale_n / new_scale_d <= max_scale
3522 scale_d = new_scale_d
3523 return scale_d
3524
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003525 # Scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003526 scale_y_n = rng.randInt(low=1, high=(1 << 11))
3527 scale_x_n = rng.randInt(low=1, high=(1 << 11))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003528
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003529 scale_y_d = rng.randInt(low=1, high=(16 * scale_y_n))
3530 scale_x_d = rng.randInt(low=1, high=(16 * scale_x_n))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003531
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003532 scale_y_d = fix_scale_to_max_scale(
3533 scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3534 )
3535 scale_x_d = fix_scale_to_max_scale(
3536 scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3537 )
3538
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003539 # Offsets and border within the scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003540 offset_y = rng.randInt(low=-scale_y_n, high=(16 * scale_y_n))
3541 offset_x = rng.randInt(low=-scale_x_n, high=(16 * scale_x_n))
3542 border_y = rng.randInt(low=(-16 * scale_y_n), high=scale_y_n)
3543 border_x = rng.randInt(low=(-16 * scale_x_n), high=scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003544
3545 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3546 offset = (offset_y, offset_x)
3547 border = (border_y, border_x)
3548 return scale, offset, border
3549
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003550 def get_level_8k_params():
3551 # Create 64x scale - 64/1 to 2048/32
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003552 scale_d = rng.randInt(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003553 low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
3554 )
3555 scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
3556 # Create half to fifth scaling
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003557 scale_d_alt = rng.randInt(low=2, high=6)
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003558 scale_n_alt = 1
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003559 switch = rng.choice((False, True))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003560 if switch:
3561 scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
3562 else:
3563 scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)
3564
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003565 offset_y = rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
3566 offset_x = rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003567 offset = (offset_y, offset_x)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003568 border_y = rng.choice((-16 * scale[0], 0, scale[0] - 1))
3569 border_x = rng.choice((-16 * scale[2], 0, scale[2] - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003570 border = (border_y, border_x)
3571 return scale, offset, border
3572
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003573 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003574 # Exclude illegal {mode, type} configurations. Pick legal output types
3575 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
3576 outputDTypeList = [DType.INT8]
3577 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
3578 outputDTypeList = [DType.INT16]
3579 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
3580 outputDTypeList = [DType.INT32]
3581 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
3582 outputDTypeList = [DType.INT48]
James Ward8b390432022-08-12 20:48:56 +01003583 elif dtype == DType.FP16:
3584 outputDTypeList = [DType.FP16]
James Ward24dbc422022-10-19 12:20:31 +01003585 elif dtype == DType.BF16:
3586 outputDTypeList = [DType.BF16]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003587 elif dtype == DType.FP32:
3588 outputDTypeList = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00003589 elif dtype == DType.FP8E4M3:
3590 outputDTypeList = [DType.FP8E4M3]
3591 elif dtype == DType.FP8E5M2:
3592 outputDTypeList = [DType.FP8E5M2]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003593 elif error_name == ErrorIf.WrongInputType:
3594 # If an incorrect input type is used then we set a 'correct'
3595 # output type to avoid other errors
3596 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
3597 else:
3598 continue
3599
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003600 arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"
3601
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003602 for outputDType in outputDTypeList:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003603 perm = 0
3604 while perm < testGen.args.num_rand_permutations:
3605 # Random choice of type of params we are testing
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003606 if not testGen.args.level8k:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003607 _rnd_param_fn = rng.choice(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003608 (
3609 get_rand_params,
3610 get_upscale_downscale_params,
3611 get_aspect_ratio_resize_params,
3612 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003613 )
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003614 scale, offset, border = _rnd_param_fn()
3615 else:
3616 scale, offset, border = get_level_8k_params()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003617
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003618 # Expand params for bounds-checking
3619 (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
3620 (offset_y, offset_x) = offset
3621 (border_y, border_x) = border
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003622
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003623 # Make sure output dimensions OH and OW are integers
3624 partial_output_y = (
3625 (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
3626 )
3627 partial_output_x = (
3628 (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
3629 )
3630 if error_name == ErrorIf.ResizeOutputShapeNonInteger:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003631 # Look for non-integer test
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003632 if (
3633 partial_output_y % scale_y_d == 0
3634 and partial_output_x % scale_x_d == 0
3635 ):
3636 # Skip this test as it doesn't produce NonInteger output
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003637 if perm > 0:
3638 perm += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003639 continue
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003640 else:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003641 # Alter the scaling factors to make the output integer
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003642 while partial_output_y % scale_y_d != 0:
3643 scale_y_d -= 1
3644 while partial_output_x % scale_x_d != 0:
3645 scale_x_d -= 1
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003646 # Make sure we are still within max scaling
3647 if (
3648 scale_y_n / scale_y_d
3649 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
3650 scale_x_n / scale_x_d
3651 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
3652 # Skip the test as it is using too large a scaling factor
3653 if perm > 0:
3654 perm += 1
3655 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003656
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003657 output_y = partial_output_y // scale_y_d + 1
3658 output_x = partial_output_x // scale_x_d + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003659
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003660 if (
3661 output_y >= testGen.args.max_resize_output_dim
3662 or output_x >= testGen.args.max_resize_output_dim
3663 ) and error_name is None:
3664 # Skip positive test if output dim will be too high
3665 # Avoid high test latency and OOM issues
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003666 if not testGen.args.level8k or perm > 0:
3667 perm += 1
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003668 continue
3669
3670 if (
3671 output_y <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003672 or output_y >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003673 or output_x <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003674 or output_x >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003675 ):
3676 # Output dimensions out of scope
3677 if error_name is not None and perm > 0:
3678 # As long as we have one ERROR_IF test, don't worry
3679 # about creating all the other permutations
3680 perm += 1
3681 continue
3682
3683 if error_name == ErrorIf.ResizeOutputShapeMismatch and (
3684 (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003685 output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003686 and output_y - scale_y_d < 1
3687 )
3688 or (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003689 output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003690 and output_x - scale_x_d < 1
3691 )
3692 ):
3693 # Can't create a negative test with these params as it
3694 # will create invalid output size
3695 if perm > 0:
3696 perm += 1
3697 continue
3698
3699 scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
3700 offset = [offset_y, offset_x]
3701 border = [border_y, border_x]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003702
3703 # Common for all data types
3704 if error_name is not None:
3705 (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003706 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003707 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003708 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003709 outputDTypeNew,
3710 ) = TosaErrorIfArgGen.eiResizeErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003711 rng,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003712 error_name,
3713 mode,
3714 dtype,
3715 shapeList,
3716 outputDType,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003717 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003718 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003719 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003720 )
3721 else:
3722 outputDTypeNew = outputDType
3723
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003724 arg_to_append = (
3725 arg_str.format(
3726 "N" if mode == ResizeMode.NEAREST else "B",
3727 testGen.typeStr(outputDTypeNew),
3728 scale[0],
3729 scale[1],
3730 scale[2],
3731 scale[3],
3732 offset[0],
3733 offset[1],
3734 border[0],
3735 border[1],
3736 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003737 {
3738 "mode": mode,
3739 "scale": scale,
3740 "offset": offset,
3741 "border": border,
3742 "output_dtype": outputDTypeNew,
3743 },
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003744 )
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003745 if arg_to_append in arg_list:
3746 # Skip already generated test params
3747 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003748
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003749 # Valid permutation
3750 perm += 1
3751 arg_list.append(arg_to_append)
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003752
3753 # Now add data generator types
3754 arg_list = TosaArgGen._add_data_generators(
3755 testGen,
3756 opName,
evacha019c96eef2024-02-07 11:21:55 +00003757 shapeList,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003758 dtype,
3759 arg_list,
3760 error_name,
3761 )
3762 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003763 return arg_list
3764
3765 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003766 def agTable(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003767 arg_list = []
3768
3769 if dtype == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003770 table = np.int32(rng.integers(low=-128, high=128, size=[256])).tolist()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003771 else: # INT16
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003772 table = np.int32(rng.integers(low=-32768, high=32768, size=[513])).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003773 # Make sure all slopes are within REQUIRE min/max 16-bit int
3774 for idx in range(len(table) - 1):
3775 slope = table[idx + 1] - table[idx]
3776 # Alter the next table entry to force the slope to be ok
3777 if slope > 32767:
3778 table[idx + 1] -= slope - 32767
3779 if slope < -32768:
3780 table[idx + 1] -= slope + 32768
3781 slope = table[idx + 1] - table[idx]
3782 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003783 arg_list.append(
3784 (
3785 "",
Jeremy Johnson587cc842024-02-08 11:45:44 +00003786 {"table": table},
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003787 )
3788 )
Jeremy Johnson587cc842024-02-08 11:45:44 +00003789 # Now add data generator types
3790 arg_list = TosaArgGen._add_data_generators(
3791 testGen,
3792 opName,
evacha019c96eef2024-02-07 11:21:55 +00003793 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003794 dtype,
3795 arg_list,
3796 error_name,
3797 )
3798 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003799 return arg_list
3800
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003801 def agCondIf(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003802 # CondIf generates the condition values here.
3803 # Convert to tensors in the build function, along with the
3804 # then and else blocks
3805 arg_list = []
3806
3807 for c in [False, True]:
Jeremy Johnson587cc842024-02-08 11:45:44 +00003808 arg_list.append(("cond{}".format(int(c)), {"condition": c}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003809
Jeremy Johnson587cc842024-02-08 11:45:44 +00003810 # Now add data generator types
3811 arg_list = TosaArgGen._add_data_generators(
3812 testGen,
3813 opName,
evacha019c96eef2024-02-07 11:21:55 +00003814 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003815 dtype,
3816 arg_list,
3817 error_name,
3818 )
3819 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003820 return arg_list
3821
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003822 def agWhileLoop(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003823 # While loop: 0 iterations, 1, more than 1
3824 arg_list = []
3825
Jeremy Johnson587cc842024-02-08 11:45:44 +00003826 for iterations in [0, 1, 4]:
3827 arg_list.append(("iter{}".format(iterations), {"iterations": iterations}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003828
Jeremy Johnson587cc842024-02-08 11:45:44 +00003829 # Now add data generator types
3830 arg_list = TosaArgGen._add_data_generators(
3831 testGen,
3832 opName,
evacha019c96eef2024-02-07 11:21:55 +00003833 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003834 dtype,
3835 arg_list,
3836 error_name,
3837 )
3838 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003839 return arg_list