blob: 41b0936161eced158950f24954a7c515c443f950 [file] [log] [blame]
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
Jeremy Johnsonaf090182024-02-13 18:25:39 +00004import logging
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01005import math
6
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
16# DTypeNames, DType, Op and ResizeMode are convenience variables to the
17# flatc-generated types that should be enums, but aren't
18
# One-time default logging configuration performed at import.
logging.basicConfig()
# NOTE(review): logger name matches the tosa_verif_build_tests tool rather than
# __name__ — presumably so all build-test modules share one logger; confirm.
logger = logging.getLogger("tosa_verif_build_tests")
21
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010022
class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(rng, zeropoint, dtype, error_name=None):
        """Return a zero point for the given dtype.

        A user supplied `zeropoint` is clamped to the dtype's range; otherwise
        a random value is chosen. For non INT8/UINT8 dtypes a deliberately
        non-zero value is produced when a *ZeroPointNotZero error is requested,
        and 0 in all other cases.
        """
        # NOTE: the dtype checks come first on purpose — INT8/UINT8 always get
        # a full-range zero point regardless of any requested error.
        if dtype == DType.INT8:
            if zeropoint is not None:
                return min(127, max(-128, zeropoint))
            return rng.randInt(-128, 128)

        if dtype == DType.UINT8:
            if zeropoint is not None:
                return min(255, max(0, zeropoint))
            return rng.randInt(0, 256)

        if error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            # Must be non-zero to trigger the ERROR_IF condition
            zero_point = rng.randInt(-128, 128)
            return zero_point if zero_point != 0 else 1

        return 0

    @staticmethod
    def qgUnary(rng, zeropoint, op, dtype, error_name=None):
        """Return [input_zp, output_zp] for a unary operator.

        The requested error (if any) is applied only to the matching slot.
        """
        input_error = (
            error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        )
        output_error = (
            error_name if error_name == ErrorIf.OutputZeroPointNotZero else None
        )
        return [
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, input_error),
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, output_error),
        ]

    @staticmethod
    def qgConv(rng, zeropoint, op, dtype_or_dtypeList, error_name=None):
        """Return [input_zp, weight_zp] for a convolution style operator.

        `dtype_or_dtypeList` is either a single dtype used for all operands or
        a list of [input, weights, accumulator] dtypes.
        """
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        input_error = (
            error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        )
        weight_error = (
            error_name if error_name == ErrorIf.WeightZeroPointNotZero else None
        )
        return [
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0], input_error),
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1], weight_error),
        ]

    @staticmethod
    def qgMatmul(rng, zeropoint, op, dtype, error_name=None):
        """Return [a_zp, b_zp] for MATMUL; both inputs share any requested error."""
        both_error = (
            error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        )
        return [
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, both_error),
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, both_error),
        ]

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into a (multiplier, shift) pair.

        Derived from computeMultiplierAndShiftTosaScale32 — `scale32` selects
        a 32-bit (31 scale bits) or 16-bit (15 scale bits) multiplier.
        Returns (multiplier, shift) with 2 <= shift <= 62.
        """
        scaleBits = 31 if scale32 else 15

        m, exponent = math.frexp(scaleFp)
        if scaleFp < 0.0:
            # frexp returns a negative mantissa for negative input; flip it so
            # the multiplier is computed from the magnitude
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        # Rounding can push the multiplier to exactly 2^scaleBits; renormalize
        if multiplier == (1 << scaleBits):
            multiplier //= 2
            exponent += 1

        shift = (-exponent) + scaleBits
        logger.debug(
            f"computeMultiplierAndShift: scalefp={scaleFp} scaleBits={scaleBits} m={m} mult={multiplier} shift={shift}"
        )

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier //= 4
            shift += 2
        elif shift == 1:
            multiplier //= 2
            shift += 1
        elif shift == 63:
            multiplier *= 2
            shift -= 1

        assert multiplier <= (1 << scaleBits)
        assert 2 <= shift <= 62

        return multiplier, shift
156
157
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.

    All generators take (testGen, rng, op, rank, error_name) and return a list
    of shapes, one per operand. NOTE: the exact order of rng calls is part of
    the reproducible-test contract — do not reorder calls.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, rng, op, rank, error_name=None):
        """Default generator: every operand gets the same random shape.

        For RankMismatch error tests, operands other than index 1 are given a
        shape of a different rank.
        """
        pl, const = op["operands"]
        shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rng, rank + rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rng, rank + rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, rng, op, rank, error_name=None):
        """Rank 4 (NHWC) generator: identical batch-constricted shape per operand."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rng, rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgGather(testGen, rng, opName, rank, error_name=None):
        """GATHER shapes: a rank 3 values tensor and an [N, W] indices shape."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_shape = testGen.makeShape(rng, rank)
        values_shape = testGen.constrictBatchSize(values_shape)

        N = values_shape[0]
        W = testGen.makeDimension(rng)
        indices_shape = [N, W]

        shape_list = [values_shape, indices_shape]
        return shape_list

    @staticmethod
    def tgScatter(testGen, rng, opName, rank, error_name=None):
        """SCATTER shapes: values_in [N,K,C], indices [N,W] and input [N,W,C] with W <= K."""
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rng, rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = rng.randInt(W_min, W_max) if W_min < W_max else W_min

        input_shape = [N, W, C]

        shape_list = []
        shape_list.append(values_in_shape)
        shape_list.append([N, W])  # indices
        shape_list.append(input_shape)

        return shape_list

    @staticmethod
    def _get_broadcast_shapes(testGen, rng, num_shapes, rank, error_name=None):
        """Return num_shapes shapes where one randomly chosen shape broadcasts.

        ~10% of non-error tests use identical shapes (no broadcasting); rank 0
        always produces empty shapes.
        """
        if rank == 0:
            # No broadcasting possible for rank 0
            return [[]] * num_shapes

        shape = testGen.makeShape(rng, rank)
        # Do not broadcast for some tests
        if error_name is None and rng.randInt(high=100) < 10:
            return [shape] * num_shapes
        shape_list = []

        # Choose any one of the inputs to broadcast
        # Note for ERRORS: Simplifies OutputShaper code if we don't change first shape
        bcast_idx = rng.randInt(0 if error_name is None else 1, num_shapes)
        fuzz_idx = rng.randInt(0, rank)

        for i in range(num_shapes):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(rng, extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, rng, op, rank, error_name=None):
        """Broadcastable shapes for all (pl + const) operands of the operator."""
        pl, const = op["operands"]
        num_shapes = pl + const
        return TosaTensorGen._get_broadcast_shapes(
            testGen, rng, num_shapes, rank, error_name
        )

    @staticmethod
    def tgMul(testGen, rng, op, rank, error_name=None):
        """MUL shapes: two broadcastable inputs plus a [1] tensor for shift."""
        # Get broadcast shapes for the first 2 inputs as the 3rd is shift
        shape_list = TosaTensorGen._get_broadcast_shapes(
            testGen, rng, 2, rank, error_name
        )
        # Add a single dimension tensor for shift
        shape_list.append([1])
        return shape_list

    @staticmethod
    def tgConv2D(testGen, rng, op, rank, error_name=None):
        """CONV2D shapes: NHWC input, OHWI filter and bias of [OC] (or [1] if broadcastable)."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension(rng)

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC or 1 if broadcastable
        try:
            if op["broadcastable_bias"]:
                if rng.choice([True, False]):
                    ofm_depth = 1
        except KeyError:
            pass
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, rng, op, rank, error_name=None):
        """CONV3D shapes: NDHWC input, ODHWI filter and [OC] bias."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension(rng)

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, rng, op, rank, error_name=None):
        """TRANSPOSE_CONV2D shapes: NHWC input, OHWI filter and [OC] bias."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension(rng)

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, rng, op, rank, error_name=None):
        """DEPTHWISE_CONV2D shapes: NHWC input, HWCM filter and [C * M] bias."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension(rng) % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, rng, op, rank, error_name=None):
        """FFT2D shapes: two NHW inputs with power-of-two H and W.

        Error tests can produce non power-of-two kernels or mismatched inputs.
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
            assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rng, rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, rng, op, rank, error_name=None):
        """RFFT2D shapes: a single NHW input with power-of-two H and W."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
            assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rng, rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, rng, op, rank, error_name=None):
        """FULLY_CONNECTED shapes: rank 2 input, [OC, IC] filter and [OC] bias."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, rng, op, rank, error_name=None):
        """MATMUL shapes: A of rank 3 and a compatible B of [N, C, OC]."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
            assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, rng, op, rank, error_name=None):
        """CONCAT shapes: base operands plus 0-3 extra tensors to concatenate.

        For ConcatInputRankMismatch error tests, every input after the first
        has a dimension randomly removed or appended.
        """
        pl, const = op["operands"]
        shape = testGen.makeShape(rng, rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = rng.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(rng, shapeList, axis, error_name=None):
        """Rework a CONCAT shape list, splitting the first shape along `axis`
        so multiple (const) inputs concatenate back to the original shape.

        NOTE: unlike the other generators this takes an existing shapeList,
        not (testGen, op, rank).
        """
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
656
657
class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each tensor in each test."""

    def __init__(self):
        pass
663
    class TVGInfo:
        """Enhanced tensor values information including data gen dict."""

        def __init__(self, tensorList, dataGenDict):
            # List of tensors created for the test
            self.tensorList = tensorList
            # Data generation meta-data for those tensors
            # NOTE(review): presumably the config consumed by the external
            # data generator — confirm against callers
            self.dataGenDict = dataGenDict
670
    # Default high value for random numbers
    # (the largest finite value representable in each floating point format)
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
        DType.FP8E4M3: 448,
        DType.FP8E5M2: 57344,
    }

    # Default lowest normal values for random numbers
    # NOTE(review): the FP8 entries (2^-9 and 2^-16) are below those formats'
    # smallest normals (2^-6 and 2^-14) — presumably intentional; confirm
    TVG_FLOAT_LOW_VALUE = {
        DType.FP32: np.exp2(-126),
        DType.FP16: np.exp2(-14),
        DType.BF16: np.exp2(-126),
        DType.FP8E4M3: np.exp2(-9),
        DType.FP8E5M2: np.exp2(-16),
    }
688
    @staticmethod
    def _get_data_range(rng, dtype, highValueLookup, lowValueLookup=None):
        # Return a tuple of (low,high) data range values for the given data
        # type using a combination of per operator table limits, data limits
        # and user supplied ranges for FP numbers
        #
        # Returns None when dtype has no entry in highValueLookup, so the
        # caller falls back to its default range. When lowValueLookup has no
        # entry the range is symmetric: (-high, high).
        if dtype in highValueLookup:
            type_range = rng.dTypeRange(dtype, high_inclusive=True)
            high_val = highValueLookup[dtype]
            if lowValueLookup is not None and dtype in lowValueLookup:
                low_val = lowValueLookup[dtype]
            else:
                low_val = -high_val
            # Set the values to something that won't produce infinity whilst
            # respecting the default ranges if more/less than the low/high
            # values
            data_range = (
                max(low_val, type_range[0]),
                min(high_val, type_range[1]),
            )
            if data_range[0] > data_range[1]:
                # Invalid data range from low to high created due to user
                # constraints revert to using internal ranges as they are
                # known to work
                logger.info(
                    f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
                )
                data_range = (low_val, high_val)
            return data_range
        return None
718
719 @staticmethod
Jeremy Johnson1271c442023-09-05 11:39:26 +0100720 def tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100721 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson1271c442023-09-05 11:39:26 +0100722 ):
723 # Variable inputs versus constants
724 pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
Jeremy Johnson3eafe662024-01-10 13:13:35 +0000725 if "p_count" in argsDict:
726 # Override for operators like CONCAT
727 pCount = argsDict["p_count"]
728 cCount = argsDict["c_count"]
729 assert pCount + cCount == len(
730 shapeList
731 ), "Placeholders & Constant tensors must match shapes list"
732
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000733 tens_ser_list = []
Jeremy Johnson1271c442023-09-05 11:39:26 +0100734
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100735 if (
736 error_name is not None
737 or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100738 or "data_gen" not in testGen.TOSA_OP_LIST[opName]
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100739 ):
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000740 # Fall back to internal data gen when dealing with unsupported types or ops
741 data_range = argsDict["data_range"] if "data_range" in argsDict else None
742 for idx, info in enumerate(zip(shapeList, dtypeList)):
Jeremy Johnson30476252023-11-20 16:15:30 +0000743 roundMode = False
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000744 shape, dtype = info
Jeremy Johnson30476252023-11-20 16:15:30 +0000745 if "data_range_list" in argsDict:
746 data_range = argsDict["data_range_list"][idx]["range"]
747 roundMode = (
748 "round" in argsDict["data_range_list"][idx]
749 and argsDict["data_range_list"][idx]["round"] is True
750 )
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000751 if data_range is not None and dtype not in (
752 DType.FP16,
753 DType.FP32,
754 DType.BF16,
Won Jeon2c34b462024-02-06 18:37:00 +0000755 DType.FP8E4M3,
756 DType.FP8E5M2,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000757 ):
758 # Change from inclusive to exclusive range
759 data_range = (data_range[0], data_range[1] + 1)
Won Jeon64e4bfe2024-01-18 06:31:55 +0000760
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100761 # Ignore lazy data gen option and create data array using any range limits
Won Jeon64e4bfe2024-01-18 06:31:55 +0000762 if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
Jeremy Johnson0a042992024-02-28 13:20:05 +0000763 if dtype == DType.SHAPE:
764 arr = np.int64(argsDict["fixed_data"][idx])
765 elif dtype == DType.INT8:
766 arr = np.int8(argsDict["fixed_data"][idx])
Tai Ly6e1e2bc2024-03-01 20:59:32 +0000767 elif dtype == DType.INT16:
768 arr = np.int16(argsDict["fixed_data"][idx])
769 elif dtype == DType.INT32:
770 arr = np.int32(argsDict["fixed_data"][idx])
Jeremy Johnson0a042992024-02-28 13:20:05 +0000771 else:
772 assert False, "Unsupported fixed_data type"
Won Jeon64e4bfe2024-01-18 06:31:55 +0000773 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100774 arr = rng.randTensor(shape, dtype, data_range)
Jeremy Johnson30476252023-11-20 16:15:30 +0000775 if roundMode:
776 arr = np.round(arr)
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000777 if idx < pCount:
778 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
779 else:
780 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100781
Jeremy Johnson1271c442023-09-05 11:39:26 +0100782 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
783
784 # Create data generator meta-data
785 dg_type = argsDict["dg_type"]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100786 tens_data = {
787 "version": "0.1",
788 "tensors": {},
789 }
790 dg_tens_meta = tens_data["tensors"]
evacha014a205112024-03-08 16:39:24 +0000791
792 fp_special_info = {}
793 fp_special_info["start_idx"] = int(rng.randInt())
794
Jeremy Johnson1271c442023-09-05 11:39:26 +0100795 for idx, shape in enumerate(shapeList):
796
797 tens_meta = {}
Won Jeon64e4bfe2024-01-18 06:31:55 +0000798 if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
799 tens_meta["generator"] = gtu.DataGenType(
800 gtu.DataGenType.FIXED_DATA
801 ).name
802 else:
803 tens_meta["generator"] = gtu.DataGenType(dg_type).name
804
Jeremy Johnson1271c442023-09-05 11:39:26 +0100805 tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
806 tens_meta["shape"] = [int(i) for i in shape]
807 tens_meta["input_pos"] = idx
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100808 tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()
Jeremy Johnson1271c442023-09-05 11:39:26 +0100809
Jeremy Johnsonc870d1e2024-04-08 16:17:47 +0100810 if testGen.args.random_const_inputs:
811 # Choose type of tensor biased by defaults
812 percentage = rng.randInt(0, 100)
813 variable = (idx < pCount and percentage < 70) or (
814 idx >= pCount and percentage >= 70
815 )
816 else:
817 # Use default set up of constants versus inputs for the op
818 variable = idx < pCount
819
820 if variable:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100821 tens_meta["input_type"] = "VARIABLE"
Jeremy Johnson1271c442023-09-05 11:39:26 +0100822 else:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100823 tens_meta["input_type"] = "CONSTANT"
Jeremy Johnson1271c442023-09-05 11:39:26 +0100824
825 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
826 info = {}
Won Jeon64e4bfe2024-01-18 06:31:55 +0000827 if (
828 tens_meta["generator"]
829 == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name
830 ):
831 info["data"] = [int(i) for i in argsDict["fixed_data"][idx]]
832 tens_meta["fixed_data_info"] = info
833 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100834 info["rng_seed"] = rng.seed
Jeremy Johnson30476252023-11-20 16:15:30 +0000835
Won Jeon64e4bfe2024-01-18 06:31:55 +0000836 data_range = None
837 if "data_range_list" in argsDict:
838 data_range = argsDict["data_range_list"][idx]["range"]
839 if "round" in argsDict["data_range_list"][idx]:
840 info["round"] = argsDict["data_range_list"][idx]["round"]
841 elif "data_range" in argsDict:
842 data_range = argsDict["data_range"]
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000843
Won Jeon64e4bfe2024-01-18 06:31:55 +0000844 if data_range is None:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100845 data_range = rng.dTypeRange(dtypeList[idx], high_inclusive=True)
Won Jeon64e4bfe2024-01-18 06:31:55 +0000846 info["range"] = [str(v) for v in data_range]
847 tens_meta["pseudo_random_info"] = info
Jeremy Johnson1271c442023-09-05 11:39:26 +0100848 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
849 info = {}
850 info["s"] = argsDict["s"]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100851 info["ks"] = int(argsDict["ks"])
852 if "acc_type" in argsDict:
853 # Convert type number into JSON name
854 info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
855 "json"
856 ]
857 if "kernel" in argsDict:
858 info["kernel"] = [int(k) for k in argsDict["kernel"]]
859 if "axis" in argsDict:
860 info["axis"] = int(argsDict["axis"])
Jeremy Johnson1271c442023-09-05 11:39:26 +0100861 tens_meta["dot_product_info"] = info
evacha019c96eef2024-02-07 11:21:55 +0000862 elif dg_type == gtu.DataGenType.FULL_RANGE:
863 info = {}
864 info["start_val"] = int(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100865 rng.randInt(0, gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["fullset"])
evacha019c96eef2024-02-07 11:21:55 +0000866 )
867 tens_meta["full_range_info"] = info
evacha014a205112024-03-08 16:39:24 +0000868 elif dg_type == gtu.DataGenType.FP_SPECIAL:
869 tens_meta["fp_special_info"] = fp_special_info
Jeremy Johnson1271c442023-09-05 11:39:26 +0100870 else:
871 # TODO - other data gen type
872 assert False, "TODO: support other data gen types"
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100873
874 # Using the finished generate config meta data - generate the data if
875 # needed and assign a tensor name from the serializer
876
877 # Need to generate data when not lazy or for the bias tensor as we need
878 # to work out if the bias data is non-zero for compliance
879 if not testGen.args.lazy_data_gen or (
880 idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
881 ):
882 # Give this tensor a temporary name until we get one from the serializer
883 temp_name = f"placeholder_{idx}"
884 dg_tens_meta[temp_name] = tens_meta
885 # Create data now using the temporary name to access meta details
886 data = testGen.dgl.get_tensor_data(temp_name, tens_data)
Won Jeon64e4bfe2024-01-18 06:31:55 +0000887 if tens_meta["data_type"] == "SHAPE":
888 # Tensor type SHAPE and Numpy file type must be the same
889 data = np.int64(data)
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100890 # Remove the item as we will give it the correct name later
891 del dg_tens_meta[temp_name]
892
893 if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
894 # The KS value used by compliance verification is altered when the
895 # bias data is non-zero
896 if max(abs(data)) > 0.0:
897 argsDict["ksb"] = argsDict["ks"] + 1
898
899 if testGen.args.lazy_data_gen:
900 data = None
901
Jeremy Johnsonc870d1e2024-04-08 16:17:47 +0100902 if variable:
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100903 tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
904 else:
905 tens = testGen.ser.addConst(shape, dtypeList[idx], data)
906
907 tens_ser_list.append(tens)
908 # Add the meta data to the list using the serializer tensor name
Jeremy Johnson1271c442023-09-05 11:39:26 +0100909 dg_tens_meta[tens.name] = tens_meta
910
Jeremy Johnson1271c442023-09-05 11:39:26 +0100911 return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
912
913 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100914 def tvgNegate(
915 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
916 ):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100917 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000918 # Integer test
919 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100920 pCount, cCount = op["operands"]
921 assert (
922 pCount == 1 and cCount == 0
923 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100924 # Must create tensors with values within accumulator (int32) negatable
925 # range
926 max_val = (1 << 31) - 1
927 min_val = -max_val
Jeremy Johnson862c0072024-04-16 12:23:44 +0100928 arr = rng.randTensor(
929 shapeList[0], dtypeList[0], data_range=(min_val, (max_val + 1))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100930 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000931 tens_ser_list = []
932 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100933 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
934 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000935 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100936 else:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000937 # ERROR_IF or floating point test
938 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100939 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100940 )
941
Jeremy Johnson30476252023-11-20 16:15:30 +0000942 # Set the ADD/SUB data range to half the largest value to avoid infinities
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000943 TVG_FLOAT_HIGH_VALUE_ADDSUB = {
944 DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
945 DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
946 DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
947 }
948
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100949 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100950 def tvgAddSub(
951 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
952 ):
Won Jeon74342e52024-01-09 00:34:40 +0000953 if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000954 # Make sure the integer operation does not cause value saturation - where
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100955 # the number wraps due to limited number of bits to store the answer
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000956 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100957 pCount, cCount = op["operands"]
958 assert (
959 pCount == 2 and cCount == 0
960 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000961 tens_ser_list = []
Won Jeon74342e52024-01-09 00:34:40 +0000962 add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
Jeremy Johnson32bf9012024-03-20 16:32:23 +0000963 data_range = None # Use default
964 if op["op"] in (Op.ADD_SHAPE, Op.SUB_SHAPE):
965 data_range = testGen.args.tensor_shape_range
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100966 a_arr = rng.randTensor(shapeList[0], dtypeList[0], data_range)
967 b_arr = rng.randTensor(shapeList[1], dtypeList[1], data_range)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100968 if add:
969 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
970 else:
971 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
972
973 # Work out the saturation limits
974 max_i32 = (1 << 31) - 1
975 min_i32 = -(1 << 31)
976 max_arr = np.full(shapeList[1], max_i32)
977 min_arr = np.full(shapeList[1], min_i32)
978
979 # Find how much values exceed the maximum/minimums
980 sat_max_arr = np.maximum(res_arr - max_arr, 0)
981 sat_min_arr = np.minimum(res_arr - min_arr, 0)
982
983 if not add:
984 # Swap saturation values and negate values as we need to perform opposite operations
985 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
986
987 # Create new array of unsaturated values by clipping values as needed
988 b_unsat_arr = b_arr
989 if (sat_max_arr != 0).any():
990 # Clip values that cause saturation
991 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
992 # Reduce axes in unsaturated tensor to match original tensor
993 for axis, dim in enumerate(b_arr.shape):
994 if dim != b_unsat_arr.shape[axis]:
995 assert (
996 dim == 1
997 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
998 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
999
1000 if (sat_min_arr != 0).any():
1001 # Clip values that cause saturation
1002 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
1003 # Reduce axes in unsaturated tensor to match original tensor
1004 for axis, dim in enumerate(b_arr.shape):
1005 if dim != b_unsat_arr.shape[axis]:
1006 assert (
1007 dim == 1
1008 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
1009 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
1010
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001011 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001012 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1013 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001014 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001015 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
1016 )
1017
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001018 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001019 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001020 # ERROR_IF or floating point test
1021 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001022 rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001023 )
1024 if data_range:
1025 argsDict["data_range"] = data_range
1026
1027 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001028 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001029 )
1030
1031 @staticmethod
1032 def tvgCondIfWhileLoop(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001033 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001034 ):
1035 if dtypeList[0] in (
1036 DType.INT32,
1037 DType.INT16,
1038 DType.INT8,
1039 ):
1040 # Limit input tensors with cond_if_binary or while_loop to stop
1041 # saturation of add/sub ops with int32 and keep all logical shift
1042 # values between 0 to 31 for int16 or int8
Jeremy Johnson587cc842024-02-08 11:45:44 +00001043 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001044 pCount, cCount = op["operands"]
1045 pRemain = pCount
Jeremy Johnson587cc842024-02-08 11:45:44 +00001046 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001047 for idx, shape in enumerate(shapeList[:]):
1048 if dtypeList[0] == DType.INT32:
Jeremy Johnson862c0072024-04-16 12:23:44 +01001049 # Limit data range to avoid saturation
1050 arr = np.int32(rng.randTensor(shapeList[idx], DType.INT16))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001051 else:
Jeremy Johnson862c0072024-04-16 12:23:44 +01001052 arr = rng.randTensor(
1053 shapeList[idx], dtypeList[0], data_range=(0, 32)
1054 )
1055
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001056 if pRemain > 0:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001057 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001058 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
1059 )
1060 pRemain -= 1
1061 else:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001062 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001063 testGen.ser.addConst(shape, dtypeList[idx], arr)
1064 )
1065
Jeremy Johnson587cc842024-02-08 11:45:44 +00001066 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001067 else:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001068 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001069 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001070 )
1071
1072 @staticmethod
1073 def tvgArithmeticRightShift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001074 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001075 ):
Jeremy Johnson587cc842024-02-08 11:45:44 +00001076 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001077 pCount, cCount = op["operands"]
1078 # Force value of operand[1] to be within [0, num_bits]
1079 assert (
1080 pCount == 2 and cCount == 0
1081 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
1082
Jeremy Johnson587cc842024-02-08 11:45:44 +00001083 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001084 for idx, shape in enumerate(shapeList[:]):
1085 if idx == 1:
1086 if dtypeList[idx] == DType.INT8:
Jeremy Johnson862c0072024-04-16 12:23:44 +01001087 arr = rng.randTensor(shape, dtypeList[idx], data_range=(0, 8))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001088 elif dtypeList[idx] == DType.INT16:
Jeremy Johnson862c0072024-04-16 12:23:44 +01001089 arr = rng.randTensor(shape, dtypeList[idx], data_range=(0, 16))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001090 elif dtypeList[idx] == DType.INT32:
Jeremy Johnson862c0072024-04-16 12:23:44 +01001091 arr = rng.randTensor(shape, dtypeList[idx], data_range=(0, 32))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001092 elif error_name == ErrorIf.WrongInputType:
Jeremy Johnson862c0072024-04-16 12:23:44 +01001093 arr = rng.randTensor(shape, DType.INT32, data_range=(0, 8))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001094 else:
1095 raise Exception("OpArithmeticRightShift: invalid input dtype")
1096 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001097 arr = rng.randTensor(shape, dtypeList[idx])
Jeremy Johnson587cc842024-02-08 11:45:44 +00001098 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001099
Jeremy Johnson587cc842024-02-08 11:45:44 +00001100 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001101
1102 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001103 def tvgReshape(
1104 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1105 ):
Won Jeon64e4bfe2024-01-18 06:31:55 +00001106 dtypeList[1] = DType.SHAPE
1107 shapeList[1] = [len(argsDict["new_shape"])]
1108 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1109 argsDict["fixed_data"] = [None, argsDict["new_shape"]]
1110
1111 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001112 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Won Jeon64e4bfe2024-01-18 06:31:55 +00001113 )
1114
1115 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001116 def tvgRescale(
1117 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1118 ):
Tai Ly6e1e2bc2024-03-01 20:59:32 +00001119 scale32 = argsDict["scale"]
1120 multiplier_arr = argsDict["multiplier"]
1121 shift_arr = argsDict["shift"]
1122
1123 if scale32:
1124 dtypeList[1] = DType.INT32
1125 else:
1126 dtypeList[1] = DType.INT16
1127 shapeList[1] = [len(multiplier_arr)]
1128 dtypeList[2] = DType.INT8
1129 shapeList[2] = [len(shift_arr)]
1130 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1131 argsDict["fixed_data"] = [None, multiplier_arr, shift_arr]
1132
1133 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001134 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Tai Ly6e1e2bc2024-03-01 20:59:32 +00001135 )
1136
1137 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001138 def tvgPad(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Tai Lye095da72024-01-25 22:00:18 +00001139 # argsDict["pad"] is 2D array, need to flatten it to get list of values
1140 pad_values = argsDict["pad"].flatten()
1141 dtypeList[1] = DType.SHAPE
1142 shapeList[1] = [len(pad_values)]
1143 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1144 argsDict["fixed_data"] = [None, pad_values]
1145
1146 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001147 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Tai Lye095da72024-01-25 22:00:18 +00001148 )
1149
1150 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001151 def tvgSlice(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
TatWai Chongf15bad82024-01-31 21:33:27 -08001152 dtypeList[1] = DType.SHAPE
1153 shapeList[1] = [len(argsDict["start"])]
1154 dtypeList[2] = DType.SHAPE
1155 shapeList[2] = [len(argsDict["size"])]
1156 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1157 argsDict["fixed_data"] = [None, argsDict["start"], argsDict["size"]]
1158
1159 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001160 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
TatWai Chongf15bad82024-01-31 21:33:27 -08001161 )
1162
1163 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001164 def tvgTile(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Won Jeon64e4bfe2024-01-18 06:31:55 +00001165 dtypeList[1] = DType.SHAPE
1166 shapeList[1] = [len(argsDict["multiples"])]
1167 argsDict["fixed_data"] = [None, argsDict["multiples"]]
1168
1169 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001170 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Won Jeon64e4bfe2024-01-18 06:31:55 +00001171 )
1172
1173 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001174 def tvgSelect(
1175 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1176 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001177 # Set datatype of condition tensor to boolean
1178 dtypeList[0] = DType.BOOL
1179
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00001180 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001181 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001182 )
1183
1184 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001185 def tvgIntDiv(
1186 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1187 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001188 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001189 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001190 pCount, cCount = op["operands"]
1191 assert (
1192 pCount == 2 and cCount == 0
1193 ), "Op.INTDIV must have 2 placeholders, 0 consts"
1194
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001195 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001196
1197 # Two invalid cases for Op.INTDIV:
1198 # 1. divisor == 0
1199 # 2. dividend == -(1<<31) and divisor == -1
1200 while True:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001201 dividend_arr = rng.randTensor(shapeList[0], dtypeList[0])
1202 divisor_arr = rng.randTensor(shapeList[1], dtypeList[1])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001203
1204 if (divisor_arr == 0).any():
1205 continue
1206
1207 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
1208 continue
1209
1210 break
1211
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001212 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001213 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1214 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001215 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001216 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1217 )
1218
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001219 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001220 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001221 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001222 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001223 )
1224
Jeremy Johnson30476252023-11-20 16:15:30 +00001225 # Set the MUL data range to the square root of the largest value
1226 # to avoid infinities
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001227 TVG_FLOAT_HIGH_VALUE_MUL = {
1228 DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1229 DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1230 DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1231 }
1232
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001233 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001234 def tvgMul(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001235 if error_name is not None or dtypeList[0] in (
1236 DType.FP16,
1237 DType.BF16,
1238 DType.FP32,
1239 ):
1240 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001241 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001242 rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001243 )
1244 if data_range:
1245 argsDict["data_range"] = data_range
1246
Jeremy Johnson0a042992024-02-28 13:20:05 +00001247 if dtypeList[0] != DType.SHAPE:
1248 # Need to supply shift tensor for MUL (not needed for MUL_SHAPE)
1249 dtypeList[2] = DType.INT8
1250 shapeList[2] = [1]
1251 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1252 argsDict["fixed_data"] = [None, None, [argsDict["shift"]]]
1253
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001254 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001255 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001256 )
1257 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001258 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001259 pCount, cCount = op["operands"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001260
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001261 tens_ser_list = []
1262
1263 # Make sure multiply result in int32 range
Won Jeon74342e52024-01-09 00:34:40 +00001264 if dtypeList[0] == DType.SHAPE:
1265 shift = 0
1266 else:
1267 shift = argsDict["shift"]
Jeremy Johnson862c0072024-04-16 12:23:44 +01001268
1269 np_type = np.int32
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001270 if dtypeList[0] == DType.INT8:
1271 num_bits = 8
Jeremy Johnson862c0072024-04-16 12:23:44 +01001272 np_type = np.int8
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001273 elif dtypeList[0] == DType.INT16:
1274 num_bits = 16
Jeremy Johnson862c0072024-04-16 12:23:44 +01001275 np_type = np.int16
Won Jeon74342e52024-01-09 00:34:40 +00001276 elif dtypeList[0] in (DType.INT32, DType.SHAPE):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001277 num_bits = 32
Jeremy Johnson862c0072024-04-16 12:23:44 +01001278 # np_type is not used for DType.SHAPE so leave as np.int32
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001279 elif error_name == ErrorIf.WrongInputType:
1280 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001281 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001282 raise Exception(
1283 f"OpMul: invalid input dtype {gtu.DTYPE_ATTRIBUTES[dtypeList[0]]['str']}"
1284 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001285
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001286 for idx, shape in enumerate(shapeList[:]):
Won Jeon74342e52024-01-09 00:34:40 +00001287 if dtypeList[idx] == DType.SHAPE:
1288 low = testGen.args.tensor_shape_range[0]
1289 high = testGen.args.tensor_shape_range[1]
1290 else:
1291 low = -(2 ** (num_bits - 1))
1292 high = (2 ** (num_bits - 1)) - 1
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001293
Jeremy Johnson862c0072024-04-16 12:23:44 +01001294 a_arr = rng.randTensor(
1295 shapeList[0], DType.INT32, data_range=(low, high)
1296 )
1297 b_arr = rng.randTensor(
1298 shapeList[1], DType.INT32, data_range=(low, high)
1299 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001300
1301 i = 0
1302 while True:
1303
1304 a_arr_64 = a_arr.astype(np.int64)
1305 b_arr_64 = b_arr.astype(np.int64)
1306
1307 if shift > 0:
1308 rounding = 1 << (shift - 1)
1309 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001310 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001311 result_arr = a_arr_64 * b_arr_64
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001312
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001313 if (result_arr > -(2**31)).all() and (
1314 result_arr <= ((2**31) - 1)
1315 ).all():
1316 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001317
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001318 i = i + 1
1319 a_arr = a_arr // 2
1320 b_arr = b_arr // 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001321
Won Jeon74342e52024-01-09 00:34:40 +00001322 if dtypeList[0] == DType.SHAPE:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001323 # MUL_SHAPE with 2 inputs
Won Jeon74342e52024-01-09 00:34:40 +00001324 tens_ser_list.append(
1325 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
1326 )
1327 tens_ser_list.append(
1328 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
1329 )
1330 else:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001331 # MUL with 3 inputs (3rd is shift)
Won Jeon74342e52024-01-09 00:34:40 +00001332 tens_ser_list.append(
Jeremy Johnson18a379d2024-03-28 15:53:21 +00001333 testGen.ser.addPlaceholder(
Jeremy Johnson862c0072024-04-16 12:23:44 +01001334 shapeList[0], dtypeList[0], a_arr.astype(np_type)
Jeremy Johnson18a379d2024-03-28 15:53:21 +00001335 )
Won Jeon74342e52024-01-09 00:34:40 +00001336 )
1337 tens_ser_list.append(
Jeremy Johnson18a379d2024-03-28 15:53:21 +00001338 testGen.ser.addPlaceholder(
Jeremy Johnson862c0072024-04-16 12:23:44 +01001339 shapeList[1], dtypeList[1], b_arr.astype(np_type)
Jeremy Johnson18a379d2024-03-28 15:53:21 +00001340 )
Won Jeon74342e52024-01-09 00:34:40 +00001341 )
Jeremy Johnson0a042992024-02-28 13:20:05 +00001342 tens_ser_list.append(
1343 testGen.ser.addPlaceholder([1], DType.INT8, np.int8([shift]))
1344 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001345
1346 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001347
1348 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001349 def tvgConcat(
1350 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1351 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001352 count = len(shapeList) - testGen.args.num_const_inputs_concat
1353 if count < 1:
1354 count = 1
1355 if testGen.args.num_const_inputs_concat == 0:
1356 count = len(shapeList)
1357
Won Jeon74342e52024-01-09 00:34:40 +00001358 op = testGen.TOSA_OP_LIST[opName]
1359 if op["op"] == Op.CONCAT_SHAPE:
1360 # Set the axis to 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001361 shapeList = TosaTensorGen.tgConcatConstInput(rng, shapeList, 0, error_name)
Won Jeon74342e52024-01-09 00:34:40 +00001362 else:
1363 shapeList = TosaTensorGen.tgConcatConstInput(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001364 rng, shapeList, argsDict["axis"], error_name
Won Jeon74342e52024-01-09 00:34:40 +00001365 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001366
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001367 # Override default pCount/cCount for operator
1368 argsDict["p_count"] = count
1369 argsDict["c_count"] = len(shapeList) - count
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001370
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001371 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001372 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001373 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001374
1375 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001376 def tvgLogicalShift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001377 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001378 ):
1379 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001380 pCount, cCount = op["operands"]
1381 assert (
1382 pCount == 2 and cCount == 0
1383 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001384 values_arr = rng.randTensor(shapeList[0], dtypeList[0])
Jeremy Johnson862c0072024-04-16 12:23:44 +01001385 shift_arr = rng.randTensor(shapeList[1], dtypeList[0], data_range=(0, 32))
1386
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001387 tens_ser_list = []
1388 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001389 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1390 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001391 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001392 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1393 )
1394
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001395 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001396
    @staticmethod
    def tvgEqual(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Tensor values generator for EQUAL-style comparison ops.

        For dtypes the compliance data generator does not support, random
        tensors are created and some positions are forced to match so the
        comparison result is not trivially all-false.
        """
        if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
            # Integer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"

            a_arr = rng.randTensor(shapeList[0], dtypeList[0])
            b_arr = rng.randTensor(shapeList[1], dtypeList[1])

            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        rng.integers(0, max(shapeList[0][axis], shapeList[1][axis]))
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                # Copy b's value over a's at this broadcast-compatible
                # position so the tensors compare equal there
                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )
1441
1442 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001443 def tvgReduceSum(
1444 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1445 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001446 dtype = dtypeList[0]
1447 if dtype == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001448 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001449 pCount, cCount = op["operands"]
1450 assert (
1451 pCount == 1 and cCount == 0
1452 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
1453 # Limit values so that the sum cannot exceed the range of an int32 during
1454 # summation of any axis
1455 range_val = int((1 << 31) / max(shapeList[0]))
Jeremy Johnson862c0072024-04-16 12:23:44 +01001456 values_arr = rng.randTensor(
1457 shapeList[0], dtype, data_range=(-range_val, range_val)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001458 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001459 tens_ser_list = []
1460 tens_ser_list.append(
Jeremy Johnson30476252023-11-20 16:15:30 +00001461 testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001462 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001463 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001464 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001465 # ERROR_IF or dot product floating point test
Jeremy Johnson30476252023-11-20 16:15:30 +00001466 if (
1467 error_name is None
1468 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
1469 ):
1470 # Limit ranges for (non error & non compliance) tests by using
1471 # values that can be summed on any axis to not hit infinity
1472 highval_lookup = {
1473 dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
1474 / max(shapeList[0])
1475 }
1476 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001477 rng, dtype, highval_lookup
Jeremy Johnson30476252023-11-20 16:15:30 +00001478 )
1479 assert data_range is not None
1480 argsDict["data_range"] = data_range
1481
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001482 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001483 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001484 )
1485
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001486 @staticmethod
1487 def tvgReduceProduct(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001488 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001489 ):
1490 dtype = dtypeList[0]
1491 if error_name is None:
1492 # Limit ranges for (non error) tests by using
1493 # values that can be multiplied on any axis to not hit infinity
1494 highval_lookup = {
1495 dtype: math.pow(
1496 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
1497 1 / max(shapeList[0]),
1498 )
1499 }
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001500 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001501 assert data_range is not None
1502 argsDict["data_range"] = data_range
1503
1504 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001505 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001506 )
1507
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001508 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001509 def tvgResize(
1510 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1511 ):
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001512 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001513 rng,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001514 dtypeList[0],
1515 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1516 )
1517 if data_range:
1518 argsDict["data_range"] = data_range
1519 # Needed for compliance
1520 argsDict["max_abs_value"] = data_range[1]
1521
Tai Lyc5c2a7e2024-02-22 23:26:28 +00001522 scale_values = argsDict["scale"]
1523 offset_values = argsDict["offset"]
1524 border_values = argsDict["border"]
1525 dtypeList[1] = DType.SHAPE
1526 dtypeList[2] = DType.SHAPE
1527 dtypeList[3] = DType.SHAPE
1528 shapeList[1] = [len(scale_values)]
1529 shapeList[2] = [len(offset_values)]
1530 shapeList[3] = [len(border_values)]
1531 argsDict["fixed_data"] = [None, scale_values, offset_values, border_values]
1532
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001533 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001534 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001535 )
1536
Jeremy Johnson30476252023-11-20 16:15:30 +00001537 # Set the POW exponent high data range
1538 TVG_FLOAT_HIGH_VALUE_POW_EXP = {
1539 DType.FP32: 10.0,
1540 DType.FP16: 10.0,
1541 DType.BF16: 10.0,
1542 }
1543 # POW highest base value (within a safe margin of error) that can be raised
1544 # to +ve exponent that doesn't become Infinity
1545 TVG_FLOAT_HIGH_VALUE_POW_BASE = {
1546 DType.FP32: math.floor(
1547 math.pow(
1548 TVG_FLOAT_HIGH_VALUE[DType.FP32],
1549 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1550 )
1551 ),
1552 DType.FP16: math.floor(
1553 math.pow(
1554 TVG_FLOAT_HIGH_VALUE[DType.FP16],
1555 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1556 )
1557 ),
1558 DType.BF16: math.floor(
1559 math.pow(
1560 TVG_FLOAT_HIGH_VALUE[DType.BF16],
1561 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1562 )
1563 ),
1564 }
1565 # POW lowest base value (within a safe margin of error) that can be raised
1566 # to -ve exponent that doesn't become Infinity
1567 TVG_FLOAT_LOW_VALUE_POW_BASE = {
1568 DType.FP32: math.ceil(
1569 math.pow(
1570 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
1571 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1572 )
1573 * 1000
1574 )
1575 / 1000,
1576 DType.FP16: math.ceil(
1577 math.pow(
1578 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
1579 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1580 )
1581 * 1000
1582 )
1583 / 1000,
1584 DType.BF16: math.ceil(
1585 math.pow(
1586 1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
1587 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1588 )
1589 * 1000
1590 )
1591 / 1000,
1592 }
1593
1594 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001595 def tvgPow(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001596 if error_name is not None:
1597 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001598 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001599 )
1600 dtype = dtypeList[0]
1601 # Different ranges for POW
1602 test_set = argsDict["s"]
1603 if test_set == 0:
1604 # Positive base with fractional exponent
1605 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001606 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001607 dtype,
1608 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1609 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1610 )
1611 exp_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001612 rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
Jeremy Johnson30476252023-11-20 16:15:30 +00001613 )
1614 exp_round = False
1615 else:
1616 # Integer exponent
1617 exp_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001618 rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
Jeremy Johnson30476252023-11-20 16:15:30 +00001619 )
1620 exp_round = True
1621 if test_set == 1:
1622 # Positive base
1623 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001624 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001625 dtype,
1626 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1627 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1628 )
1629 else:
1630 assert test_set == 2
1631 # Negative base
1632 # Supply new look up tables with negative values
1633 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001634 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001635 dtype,
1636 {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
1637 {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
1638 )
1639
1640 data_range_list = (
1641 {
1642 "range": base_range,
1643 },
1644 {
1645 "range": exp_range,
1646 "round": exp_round,
1647 },
1648 )
1649 argsDict["data_range_list"] = data_range_list
1650 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001651 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001652 )
1653
1654 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001655 def tvgLogRsqrt(
1656 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1657 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001658 # LOG & RSQRT data range from lowest expressible positive number to
1659 # largest to avoid NaNs
1660 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001661 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001662 dtypeList[0],
1663 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1664 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1665 )
1666 if data_range:
1667 argsDict["data_range"] = data_range
1668
1669 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001670 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001671 )
1672
    # Set the EXP data range to the log of the largest to smallest values
    # to avoid infinities or making the result zero
    TVG_FLOAT_HIGH_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
    # EXP of values below these limits underflows towards zero
    TVG_FLOAT_LOW_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
    }
1685
1686 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001687 def tvgExp(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001688 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001689 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001690 dtypeList[0],
1691 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1692 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1693 )
1694 if data_range:
1695 argsDict["data_range"] = data_range
1696
1697 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001698 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001699 )
1700
1701 @staticmethod
1702 def tvgFullyConnected(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001703 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson30476252023-11-20 16:15:30 +00001704 ):
1705 dtype = dtypeList[0]
1706 if (
1707 error_name is None
1708 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
Jeremy Johnson718f3472023-11-30 14:18:19 +00001709 and dtype in (DType.BF16,)
Jeremy Johnson30476252023-11-20 16:15:30 +00001710 ):
Jeremy Johnson718f3472023-11-30 14:18:19 +00001711 # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
Jeremy Johnson30476252023-11-20 16:15:30 +00001712 # Limit ranges for (non error & non compliance) FP tests by using
1713 # values that can be multiplied on any axis to not hit infinity/NaN
1714 IC = shapeList[0][1]
1715 highval_lookup = {
1716 dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
1717 }
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001718 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnson30476252023-11-20 16:15:30 +00001719 assert data_range is not None
1720 argsDict["data_range"] = data_range
1721
1722 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001723 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001724 )
1725
Jeremy Johnson708da822023-11-15 16:25:45 +00001726 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001727 def tvgCast(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson708da822023-11-15 16:25:45 +00001728 in_dtype = dtypeList[0]
1729 out_dtype = argsDict["out_type"]
1730 # Create look up to limit input tensor to output type maximums to avoid
1731 # FP infinities and saturation of integers
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001732 out_range = rng.dTypeRange(out_dtype, high_inclusive=True)
Jeremy Johnson708da822023-11-15 16:25:45 +00001733 highval_lookup = {in_dtype: out_range[1]}
1734 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001735 rng,
Jeremy Johnson708da822023-11-15 16:25:45 +00001736 in_dtype,
1737 highval_lookup,
1738 )
1739
1740 assert data_range is not None
1741 argsDict["data_range"] = data_range
1742
1743 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001744 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson708da822023-11-15 16:25:45 +00001745 )
1746
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001747 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001748 def tvgGather(
1749 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1750 ):
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001751 K = shapeList[0][1]
1752
1753 # Fix the type of the indices tensor
1754 dtypeList[1] = DType.INT32
1755
1756 dtype = dtypeList[0]
1757 if not gtu.dtypeIsSupportedByCompliance(dtype):
1758 # Test unsupported by data generator
1759 op = testGen.TOSA_OP_LIST[opName]
1760 pCount, cCount = op["operands"]
1761 assert (
1762 pCount == 2 and cCount == 0
1763 ), "Op.GATHER must have 2 placeholders, 0 consts"
1764
1765 tens_ser_list = []
1766 for idx, shape in enumerate(shapeList):
1767 dtype = dtypeList[idx]
1768 if idx != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001769 arr = rng.randTensor(shape, dtype)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001770 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1771 else:
1772 # Limit data range of indices tensor upto K (exclusive)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001773 arr = rng.randTensor(shape, dtype, (0, K))
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001774 # To match old functionality - create indices as CONST
1775 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1776
1777 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1778
1779 else:
1780 # ERROR_IF or floating point test
1781 # Use inclusive values upto index K for indices tensor
1782 data_range_list = (
1783 {"range": None},
1784 {"range": (0, K - 1)},
1785 )
1786 argsDict["data_range_list"] = data_range_list
1787
1788 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001789 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001790 )
1791
    @staticmethod
    def tvgScatter(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Tensor values generator for SCATTER.

        Builds an indices tensor whose values stay within dimension K of the
        values_in tensor and never repeat an output index, as the
        specification requires.
        """
        K = shapeList[0][1]
        W = shapeList[2][1]

        # Work out an indices tensor here with data that doesn't exceed the
        # dimension K of the values_in tensor and does NOT repeat the same K
        # location as needed by the spec:
        # "It is not permitted to repeat the same output index within a single
        # SCATTER operation and so each output index occurs at most once."
        assert K >= W, "Op.SCATTER W must be smaller or equal to K"

        # Fix the type of the indices tensor
        dtypeList[1] = DType.INT32

        dtype = dtypeList[0]
        if not gtu.dtypeIsSupportedByCompliance(dtype):
            # Test unsupported by data generator
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 3 and cCount == 0
            ), "Op.SCATTER must have 3 placeholders, 0 consts"

            tens_ser_list = []
            for idx, shape in enumerate(shapeList):
                dtype = dtypeList[idx]
                if idx != 1:
                    # values_in and input tensors - any random data
                    arr = rng.randTensor(shape, dtype)
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    # Create the indices array
                    assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
                    arr = []
                    for n in range(shape[0]):
                        # Get a shuffled list of output indices (0 to K-1) and
                        # limit length to W
                        arr.append(rng.permutation(K)[:W])
                    indices_arr = np.array(arr, dtype=np.int32)  # (N, W)
                    # To match old functionality - create indices as CONST
                    tens_ser_list.append(
                        testGen.ser.addConst(shape, dtype, indices_arr)
                    )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        else:
            # ERROR_IF or floating point test
            # Use inclusive values upto index K for indices tensor
            data_range_list = (
                {"range": None},
                {"range": (0, K - 1)},
                {"range": None},
            )
            argsDict["data_range_list"] = data_range_list

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )
1853
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001854
1855class TosaArgGen:
1856 """Argument generators create exhaustive or random lists of attributes for
1857 operators that take attributes or other parameters.
1858
1859 The return value is a list of (descriptive_name, [arglist]) tuples where
1860 the descriptive_name is appended to the test name and the arglist is expanded
1861 as arguments to the operator build function.
1862 """
1863
    def __init__(self):
        # Stateless - all argument generators are static methods
        pass
1866
    @staticmethod
    def _add_data_generators(testGen, opName, shapeList, dtype, arg_list, error_name):
        """Add extra tests for each type of data generator for this op.

        Expands each (arg_str, args_dict) entry in arg_list into one or more
        entries per applicable data generator type, tagging each result dict
        with its "dg_type". Entries that cannot be generated (too few dot
        product calculations, tensor too small for full-range data) are
        dropped with an informational log message.
        """
        if (
            error_name is None
            and "data_gen" in testGen.TOSA_OP_LIST[opName]
            and gtu.dtypeIsSupportedByCompliance(dtype)
        ):
            # Use the op's per-dtype generator list, defaulting to random
            dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"].get(
                dtype, (gtu.DataGenType.PSEUDO_RANDOM,)
            )

        else:
            # Error test or No data generator types listed - assume random
            dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)

        # Expand arg list with other data generator types
        new_arg_list = []
        for dg_type in dataGenTypesList:
            for arg_str, args_dict in arg_list:
                gen_args_dict = args_dict.copy()
                # Only create one test by default - no sets of tests
                num_test_sets = 0

                if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                    if error_name is None:
                        num_test_sets = args_dict.get("num_test_sets", 0)

                elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                    # Extra tests for each dot product test set
                    dot_products = args_dict["dot_products"]
                    if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
                        shape_info = (
                            " ({})".format(testGen.shapeStr(args_dict["shape"]))
                            if "shape" in args_dict
                            else ""
                        )
                        logger.info(
                            f"Skipping {opName}{shape_info} {gtu.DTYPE_ATTRIBUTES[dtype]['json']} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
                        )
                        continue
                    # KS and acc_type is required by all dot product generators
                    assert "ks" in args_dict
                    assert "acc_type" in args_dict

                    num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS

                elif dg_type == gtu.DataGenType.FULL_RANGE:
                    tensor_size = gtu.product(shapeList[0])
                    if tensor_size < gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]:
                        shape_info = " ({})".format(shapeList[0])
                        logger.info(
                            f"Skipping {opName}{shape_info} as tensor data size too small for full range of values {tensor_size} < {gtu.DTYPE_ATTRIBUTES[dtype]['fullset']}"
                        )
                        continue
                    # Large enough tensor data size for full range, add full test
                    arg_str = f"{arg_str}_full" if arg_str else "full"
                    gen_args_dict["tags"] = args_dict.get("tags", []) + [
                        "non_finite_fp_data"
                    ]

                elif dg_type == gtu.DataGenType.FP_SPECIAL:
                    # Special tests require all inputs to share one shape, so
                    # pre-broadcast the shapes when they differ
                    shapes_set = {tuple(x) for x in shapeList}
                    if len(shapes_set) != 1:
                        logger.info(
                            f"Changing {opName} input shapes {shapes_set} - broadcasting incompatable with special test"
                        )
                        shapeList = [np.int32(np.broadcast_shapes(*shapeList))] * len(
                            shapeList
                        )
                    arg_str = f"{arg_str}_fs" if arg_str else "fs"
                    gen_args_dict["tags"] = args_dict.get("tags", []) + [
                        "non_finite_fp_data"
                    ]

                gen_args_dict["dg_type"] = dg_type
                if num_test_sets > 0:
                    # Duplicate the test once per test set, suffixing "_sN"
                    for s in range(0, num_test_sets):
                        set_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
                        set_args_dict = gen_args_dict.copy()
                        set_args_dict["s"] = s
                        new_arg_list.append((set_arg_str, set_args_dict))
                else:
                    # Default is a single test
                    new_arg_list.append((arg_str, gen_args_dict))

        return new_arg_list
1954
1955 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001956 def agNone(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001957 """A trivial argument generator for operators that don't take any
1958 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001959 arg_list = TosaArgGen._add_data_generators(
1960 testGen,
1961 opName,
evacha019c96eef2024-02-07 11:21:55 +00001962 shapeList,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001963 dtype,
1964 [("", {})],
1965 error_name,
1966 )
1967 # Return list of tuples: (arg_str, args_dict)
1968 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001969
1970 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001971 def agPow(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001972 """Pow operator needs different test sets to cover random numbers
1973 without creating NaNs or Infs"""
1974 arg_list = TosaArgGen._add_data_generators(
1975 testGen,
1976 opName,
evacha019c96eef2024-02-07 11:21:55 +00001977 shapeList,
Jeremy Johnson30476252023-11-20 16:15:30 +00001978 dtype,
1979 [("", {"num_test_sets": 3})],
1980 error_name,
1981 )
1982 # Return list of tuples: (arg_str, args_dict)
1983 return arg_list
1984
1985 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001986 def agAxis(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001987 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001988 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001989 shape = shapeList[0]
1990
1991 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001992 # Set too small axis
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001993 axes = [rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001994 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001995 # Set too large axis
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001996 axes = [rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001997 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001998 # Create tests for each dimension
1999 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002000
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002001 opid = testGen.TOSA_OP_LIST[opName]["op"]
2002
2003 for a in axes:
2004 args_dict = {"axis": int(a)}
2005 if opid == Op.REDUCE_SUM:
Jeremy Johnsone52c0a32024-03-11 09:58:24 +00002006 output_shape = shape.copy()
2007 if error_name is None:
2008 # It only matters that we calculate the dot_products correctly
2009 # for non error_if tests as they should never be run
2010 output_shape[a] = 1
2011 args_dict["dot_products"] = gtu.product(output_shape)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002012 args_dict["shape"] = shape
2013 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
2014 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
2015
2016 arg_list.append(("axis{}".format(a), args_dict))
2017
2018 arg_list = TosaArgGen._add_data_generators(
2019 testGen,
2020 opName,
evacha019c96eef2024-02-07 11:21:55 +00002021 shapeList,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002022 dtype,
2023 arg_list,
2024 error_name,
2025 )
2026 # Return list of tuples: (arg_str, args_dict)
2027 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002028
2029 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002030 def _calculate_sparsity(num_tests, sparsity_factor):
2031 sparsity = num_tests // sparsity_factor + 1
2032 # If there are only a small number of tests, just select them all
2033 if sparsity < 13:
2034 sparsity = 1
2035 # To get a variety of parameter combinations sparsity should not be a
2036 # multiple of 2, 3 or 5
2037 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
2038 sparsity += 1
2039 return sparsity
2040
    # Maximum number of error_if variants to produce
    # Used to stop generating more negative (ERROR_IF) test cases once this
    # many have been created - see its use in agConv
    MAX_TESTS_ERROR_IFS = 3
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002043
    @staticmethod
    def agConv(testGen, rng, opName, shapeList, dtypes, error_name=None):
        """Build the stride/pad/dilation/accumulator argument list for
        CONV2D, CONV3D and DEPTHWISE_CONV2D.

        Args:
            testGen: test generator providing TOSA_OP_LIST, args and typeStr
            rng: random number generator used for ERROR_IF values
            opName: operator name, key into testGen.TOSA_OP_LIST
            shapeList: [ifm_shape, filter_shape, ...]
            dtypes: tensor dtypes for the operation (dtypes[0] is the input)
            error_name: optional ErrorIf name to create negative tests for

        Returns:
            List of tuples: (arg_str, args_dict)
        """
        # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
        arg_list = []

        # Shape: Batches, (Depth), Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: (OFM channels), (KD), KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)

        if error_name == ErrorIf.WrongAccumulatorType:
            accum_dtypes = (
                [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
            )

        # For op type checks
        op = testGen.TOSA_OP_LIST[opName]

        # Check the rank
        rank = 5 if op["op"] == Op.CONV3D else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits channels
        k_rank = rank - 2
        k_pos = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
        k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
        # compliance size - KS
        k_size = gtu.product(k_shape)
        if not op["op"] == Op.DEPTHWISE_CONV2D:
            k_size *= ifm_shape[-1]

        def get_conv_output_info(p, s, d, fix_up_padding=False):
            # Work out remainders and output dimensions with an
            # option to adjust paddings to create a valid operation
            nonlocal ifm_shape, k_shape, error_name, k_rank
            if fix_up_padding:
                p = list(p)  # Make paddings editable
            outputs_no_stride = []
            remainders = []
            outputs = []
            for index in range(k_rank):
                pad_offset = index * 2
                fixed = False
                # Fix up pad values to produce valid conv2d
                while not fixed:
                    # Output dimension without being adjusted for stride
                    output_no_stride = (
                        ifm_shape[index + 1]
                        - 1
                        + p[pad_offset]
                        + p[pad_offset + 1]
                        - (k_shape[index] - 1) * d[index]
                    )
                    # Tensor left over after applying striding
                    remainder = output_no_stride % s[index]
                    if not fix_up_padding:
                        # Just want remainders and outputs
                        break
                    if output_no_stride <= 0:
                        p[pad_offset + 1] += abs(output_no_stride) + 1
                        continue
                    if error_name == ErrorIf.ConvOutputShapeNonInteger:
                        if remainder:
                            # Conditions to trigger the test
                            fixed = True
                        else:
                            p[pad_offset + 1] += 1
                    else:
                        if remainder:
                            # Stride will be negative for StrideSmallerOne
                            assert remainder > 0 or (
                                error_name == ErrorIf.StrideSmallerOne and remainder < 0
                            )
                            p[pad_offset + 1] += abs(remainder)
                        else:
                            fixed = True
                outputs_no_stride.append(output_no_stride)
                remainders.append(remainder)
                # Output dimension taking in to account stride
                outputs.append((output_no_stride // s[index]) + 1)

            if fix_up_padding:
                p = tuple(p)  # Make the paddings read-only
                assert min(outputs_no_stride) > 0, "Fix up did not work!"
            return p, remainders, outputs, outputs_no_stride

        # Only fix up padding for conv2d and float types currently
        fix_up_padding = gtu.dtypeIsFloat(dtypes[0]) and op["op"] == Op.CONV2D
        # Allow any size of output dimension
        max_dim_size = None
        # Include all tests by default
        sparsity = 1

        # Work out padding, strides and dilation ranges depending on
        # error and arguments
        if error_name in (
            ErrorIf.PadSmallerZero,
            ErrorIf.StrideSmallerOne,
            ErrorIf.DilationSmallerOne,
        ):
            # Use specific invalid value(s)
            if error_name == ErrorIf.PadSmallerZero:
                # Create negative paddings but with positive opposite paddings
                neg_pad = rng.choice(range(-5, 0))
                p_vals = [neg_pad, abs(neg_pad)]
            else:
                p_vals = [0, 0]
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [rng.choice(range(-5, 0))]
            else:
                s_vals = [1]
            if error_name == ErrorIf.DilationSmallerOne:
                d_vals = [rng.choice(range(-5, 1))]
            else:
                d_vals = [1]
            paddings = {tuple(p_vals) * k_rank}
            strides = {tuple(s_vals) * k_rank}
            dilations = {tuple(d_vals) * k_rank}

            fix_up_padding = True  # Need to fix up paddings to be valid

        elif testGen.args.level8k and error_name is None:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            dilation_shape = [1] * k_rank
            pad_shape = [0] * k_rank * 2
            if op["op"] == Op.CONV3D:
                # Small stride apart from for big kernel (see below) to keep
                # tensor size/calculation small
                stride_shape = [1] * k_rank
                for idx in range(k_rank):
                    pad_offset = idx * 2
                    if k_shape[idx] == bigKernel:
                        # Padding shape needs to account for tensor shape
                        pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                        pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                        # Big stride to reduce output size
                        stride_shape[idx] = bigKernel
                    else:
                        # Account for kernel size
                        pad_shape[pad_offset] = k_shape[idx] - 1
            else:
                # Always have a large stride with extra padding and dilation to keep
                # tensor calculation reasonable
                stride_shape = [bigKernel] * k_rank
                for idx in range(k_rank):
                    # Dilation shape must account for kernel size
                    dilation_shape[idx] = bigKernel // k_shape[idx]
                    # Padding shape needs to accommodate tensor/kernel & dilation
                    pad_offset = idx * 2
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

            strides = {tuple(stride_shape)}
            dilations = {tuple(dilation_shape)}
            paddings = {tuple(pad_shape)}
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

            # Currently allow all combinations that are reasonable size
            sparsity = 1
        else:
            # Generate comprehensive argument lists
            p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
            # Stride must be greater than 1 to force non-integer error
            startStride = 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
            s_vals = [x for x in range(startStride, testGen.args.max_conv_stride + 1)]
            d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]

            strides = {x for x in itertools.product(*([s_vals] * k_rank))}
            dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

            if error_name is None and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[0, bigPadding]] * (k_rank * 2))
                            )
                        }
                    )
                bigStride = 8
                strides.update(
                    {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
                )
                bigDilation = 7
                dilations.update(
                    {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
                )

            if error_name is None:
                # There are too many parameter combinations, so generate them sparsely,
                sparsity_factor = 120
                sparsity = TosaArgGen._calculate_sparsity(
                    len(paddings) * len(strides) * len(dilations), sparsity_factor
                )

        # Run through all the argument options creating valid test cases
        more_tests = True
        n = 0
        for a in accum_dtypes:
            for s in sorted(list(strides)):
                for p in sorted(list(paddings)):
                    for d in sorted(list(dilations)):
                        if more_tests and (n % sparsity == 0):
                            (
                                p,
                                remainders,
                                outputs,
                                outputs_no_stride,
                            ) = get_conv_output_info(p, s, d, fix_up_padding)
                            # Following is like checking each dimension N:
                            # (ifm_shape[N+1] - 1 + p[N*2] + p[N*2+1]) > d[N] * (k_shape[N] - 1)
                            if min(outputs_no_stride) <= 0:
                                # Not a valid operation
                                n += 1  # Increment count of tests
                                continue

                            if (
                                # the parameters must produce integer exact output
                                error_name != ErrorIf.ConvOutputShapeNonInteger
                                and max(remainders) == 0
                            ) or (
                                error_name == ErrorIf.ConvOutputShapeNonInteger
                                and max(remainders) > 0
                            ):
                                if (
                                    max_dim_size is not None
                                    and max(outputs) >= max_dim_size
                                ):
                                    # Test will consume too much memory - skip it
                                    # NOTE(review): this continue skips the trailing
                                    # n += 1 below, unlike the invalid-op path above
                                    # which increments explicitly - confirm intended
                                    logger.debug(
                                        "agConv: Convolution output too big - skipped"
                                    )
                                    continue

                                # Compliance - number of dot product calculations
                                if op["op"] == Op.DEPTHWISE_CONV2D:
                                    # N*OH*OW*C*M
                                    dots = gtu.product(
                                        (ifm_shape[0], *outputs, *filter_shape[2:])
                                    )
                                else:
                                    # N*OH*OW*OC or N*OD*OH*OW*OC
                                    dots = gtu.product(
                                        (ifm_shape[0], *outputs, filter_shape[0])
                                    )
                                args_dict = {
                                    "acc_type": a,
                                    "stride": s,
                                    "pad": p,
                                    "dilation": d,
                                    "kernel": k_shape,
                                    "ks": k_size,
                                    "dot_products": dots,
                                    "shape": ifm_shape,
                                }

                                # Support for larger values than 9 needs different delimiter
                                delim = "" if max(s + p + d) <= 9 else "x"
                                arg_list.append(
                                    (
                                        "acc{}_st{}_pad{}_dilat{}".format(
                                            testGen.typeStr(a),
                                            delim.join([str(x) for x in s]),
                                            delim.join([str(x) for x in p]),
                                            delim.join([str(x) for x in d]),
                                        ),
                                        args_dict,
                                    )
                                )
                                if (
                                    error_name
                                    and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS
                                ):
                                    # Found enough errors
                                    logger.debug(
                                        f"Skipping creating more conv error tests for {error_name}"
                                    )
                                    more_tests = False
                        n += 1

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            shapeList,
            dtypes[0],
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2348
2349 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002350 def agFullyConnected(testGen, rng, opName, shapeList, dtypes, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002351
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002352 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002353 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01002354
2355 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002356 accum_dtype = gtu.get_wrong_output_type(opName, rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01002357 elif error_name == ErrorIf.WrongInputType:
2358 # Pick some potentially correct output dtype if input type is incorrect
2359 accum_dtype = DType.INT32
2360 else:
Tai Lyf36f2562024-03-14 16:21:29 +00002361 accum_dtype = dtypes[-1] # use output dtype as accum_dtype
James Ward8b390432022-08-12 20:48:56 +01002362
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002363 # Set up compliance info
2364 args_dict = {
2365 "acc_type": accum_dtype,
2366 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
2367 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
2368 "shape": shapeList[0],
2369 }
2370
2371 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
2372
2373 arg_list = TosaArgGen._add_data_generators(
2374 testGen,
2375 opName,
evacha019c96eef2024-02-07 11:21:55 +00002376 shapeList,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002377 input_dtype,
2378 arg_list,
2379 error_name,
2380 )
2381 # Return list of tuples: (arg_str, args_dict)
2382 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002383
2384 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002385 def agMatMul(testGen, rng, opName, shapeList, dtype, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002386 # Get valid accumulate type(s)
2387 if dtype == DType.INT8:
2388 accum_dtypes = [DType.INT32]
2389 elif dtype == DType.INT16:
2390 accum_dtypes = [DType.INT48]
2391 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002392 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002393 elif dtype == DType.BF16:
2394 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002395 elif dtype == DType.FP32:
2396 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002397 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2398 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002399 elif error_name is None:
2400 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
2401
2402 if error_name == ErrorIf.WrongOutputType:
2403 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002404 accum_dtypes = [gtu.get_wrong_output_type(opName, rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01002405 elif error_name == ErrorIf.WrongInputType:
2406 # Pick some potentially correct output dtype if input type is incorrect
2407 accum_dtypes = [DType.INT32]
2408
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002409 # Set up compliance info
2410 args_dict = {
2411 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
2412 # Set dot_products = N*H*W
2413 "dot_products": gtu.product(
2414 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
2415 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002416 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002417 }
2418
2419 # Create arg tuple of string and dict
2420 arg_list = []
2421 for a in accum_dtypes:
2422 d = args_dict.copy()
2423 d["acc_type"] = a
2424 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01002425
2426 arg_list = TosaArgGen._add_data_generators(
2427 testGen,
2428 opName,
evacha019c96eef2024-02-07 11:21:55 +00002429 shapeList,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002430 dtype,
2431 arg_list,
2432 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002433 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002434 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01002435 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002436
2437 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002438 def agTransposeConv2D(testGen, rng, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002439 arg_list = []
2440
Jeremy Johnson0c716862023-04-13 17:18:19 +01002441 if testGen.args.level8k and error_name is not None:
2442 # Don't produce negative large tests
2443 return arg_list
2444
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002445 ifm_shape = shapeList[0]
2446 filter_shape = shapeList[1]
2447
Tai Lyf36f2562024-03-14 16:21:29 +00002448 accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
2449
2450 if error_name == ErrorIf.WrongAccumulatorType:
2451 accum_dtypes = (
2452 [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
2453 )
James Ward8b390432022-08-12 20:48:56 +01002454
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002455 # Must be rank 4
2456 if error_name != ErrorIf.WrongRank:
2457 assert len(ifm_shape) == 4
2458 assert len(filter_shape) == 4
2459
Jeremy Johnson0c716862023-04-13 17:18:19 +01002460 k_shape = tuple(filter_shape[1:3])
Jeremy Johnson95a67102024-01-10 14:16:39 +00002461 # compliance size - KS
2462 k_size = gtu.product((*k_shape, ifm_shape[3]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002463
Jeremy Johnson0c716862023-04-13 17:18:19 +01002464 if not testGen.args.level8k:
2465 # Generate comprehensive argument lists
2466 # - except for named errors, which use specific invalid value(s)
2467 smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
2468 if error_name == ErrorIf.PadLargerEqualKernel:
2469 max_filter_size = -max(k_shape[0], k_shape[1])
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002470 p_vals = [rng.choice(range(max_filter_size - 10, max_filter_size))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002471 else:
2472 p_vals = [
2473 x
2474 for x in range(
2475 smallest_padding_size, testGen.args.max_conv_padding + 1
2476 )
2477 ]
2478 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2479 if error_name == ErrorIf.StrideSmallerOne:
2480 # Can't use stride=0, as it is used to derive output shape, as a divisor
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002481 s_vals = [rng.choice(range(-5, 0))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002482 else:
2483 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
2484 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002485
Jeremy Johnson0c716862023-04-13 17:18:19 +01002486 if not error_name and testGen.args.oversize:
2487 # add some oversize argument values
2488 if max(ifm_shape) < 64:
2489 bigPadding = 9
2490 paddings.update(
2491 {
2492 x
2493 for x in itertools.product(
2494 *([[smallest_padding_size, bigPadding]] * 4)
2495 )
2496 }
2497 )
2498 bigStride = 8
2499 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
2500
2501 # There are too many parameter combinations, so generate them sparsely,
2502 # very sparse for negative tests
2503 sparsity_factor = 2 if error_name else 10
2504 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
2505 # If there are only a small number of tests, just select them all
2506 if sparsity < 13:
2507 sparsity = 1
2508 # To get a variety of parameter combinations sparsity should not be a
2509 # multiple of 2, 3 or 5
2510 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
2511 sparsity += 1
2512 else:
2513 # Only test 8k levels boundaries
2514 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2515 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2516 bigPadding = bigKernel
2517
2518 pad_shape = [0] * (len(k_shape) * 2)
2519 stride_shape = [1] * len(k_shape)
2520 # The point at which input dimension combined with the stride will
2521 # create large output sizes!
2522 LARGE_SIZE = 2
2523 for idx in range(len(k_shape)):
2524 pad_offset = idx * 2
2525 if k_shape[idx] == bigKernel:
2526 # Set large stride
2527 stride_shape[idx] = bigKernel
2528 # Use negative output padding to reduce shape size
2529 pad_shape[pad_offset] = -(bigPadding - 1)
2530 if ifm_shape[idx + 1] > LARGE_SIZE:
2531 pad_shape[pad_offset + 1] = -(bigPadding - 1)
2532 else:
2533 # The other dimension should be the bigKernel
2534 alt_idx = 1 - idx
2535 if (
2536 k_shape[alt_idx] == bigKernel
2537 and ifm_shape[alt_idx + 1] < LARGE_SIZE
2538 ):
2539 # As the input is small, the large stride won't
2540 # affect the output so we can add some padding
2541 pad_shape[pad_offset + 1] = bigPadding
2542
2543 strides = {tuple(stride_shape)}
2544 paddings = {tuple(pad_shape)}
2545
2546 # Currently allow all combinations that are reasonable size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002547 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002548
2549 n = 0
Tai Lyf36f2562024-03-14 16:21:29 +00002550 for a in accum_dtypes:
2551 for s in sorted(list(strides)):
2552 for p in sorted(list(paddings)):
2553 if n % sparsity == 0:
2554 # Determine the output shape
2555 oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
2556 ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
2557 os = [ifm_shape[0], oh, ow, filter_shape[0]]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002558
Tai Lyf36f2562024-03-14 16:21:29 +00002559 # N*OH*OW*OC
2560 dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
2561 args_dict = {
2562 "acc_type": a,
2563 "stride": s,
2564 "pad": p,
2565 "kernel": k_shape,
2566 "ks": k_size,
2567 "dot_products": dots,
2568 "shape": ifm_shape,
2569 "out_shape": os,
2570 }
Jeremy Johnson95a67102024-01-10 14:16:39 +00002571
Tai Lyf36f2562024-03-14 16:21:29 +00002572 # Support for larger values than 9 needs different delimiter
2573 delim = "" if max(s + p) <= 9 else "x"
2574 arg_list.append(
2575 (
2576 "acc{}_st{}_pad{}_os{}".format(
2577 testGen.typeStr(a),
2578 delim.join([str(x) for x in s]),
2579 delim.join([str(x) for x in p]),
2580 "x".join([str(x) for x in os]),
2581 ),
2582 args_dict,
2583 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002584 )
Tai Lyf36f2562024-03-14 16:21:29 +00002585 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002586
Jeremy Johnson95a67102024-01-10 14:16:39 +00002587 arg_list = TosaArgGen._add_data_generators(
2588 testGen,
2589 opName,
evacha019c96eef2024-02-07 11:21:55 +00002590 shapeList,
Jeremy Johnson95a67102024-01-10 14:16:39 +00002591 dtypes[0],
2592 arg_list,
2593 error_name,
2594 )
2595 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002596 return arg_list
2597
    @staticmethod
    def agPad(testGen, rng, opName, shapeList, dtype, error_name=None):
        """Build the padding/pad-constant argument list for the PAD operator.

        Args:
            testGen: test generator providing args
            rng: random number generator for pad constants and ERROR_IF values
            opName: operator name
            shapeList: [input_shape, ...]
            dtype: input tensor dtype
            error_name: optional ErrorIf name to create negative tests for

        Returns:
            List of tuples: (arg_str, args_dict)
        """
        rank = len(shapeList[0])

        if error_name is None and testGen.args.oversize:
            # Use a fixed set of large padding values
            pad_values = [6, 7, 10, 13]
        elif error_name == ErrorIf.PadSmallerZero:
            # Negative padding values to trigger the error check
            pad_values = [x for x in range(-2, 0)]
        else:
            # Exhaustively test combinations of padding on each side of each dimension
            # - the range of padding values is defined by pad_min and pad_max
            pad_min, pad_max = 0, 1
            pad_values = [x for x in range(pad_min, pad_max + 1)]

        # Calculate pad combinations
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        # Pick a random pad constant of the matching kind (int vs float)
        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = rng.randNumberDType(dtype)
            pad_const_fp = 0
        elif gtu.dtypeIsFloat(dtype):
            pad_const_int = 0
            pad_const_fp = rng.randNumberDType(dtype)
        else:
            assert error_name == ErrorIf.WrongInputType
            pad_const_int = 0
            pad_const_fp = 0

        list_shape_pad_values = list(shape_pad_values)
        # If we are producing tests for rank 6 or greater use sparsity
        if len(list_shape_pad_values) > 1024:
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(list_shape_pad_values), sparsity_factor
            )
        else:
            sparsity = 1

        # Build arg list
        arg_list = []
        for n, paddings in enumerate(list_shape_pad_values):
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                    # Drop combinations where a dimension has no negative pad
                    # left (i.e. it was reset above)
                    if all([p > -1 for p in paddings[i]]):
                        args_valid = False
            if args_valid and n % sparsity == 0:
                # Work out name
                pad_list = []
                for r in range(rank):
                    pad_list.extend(paddings[r])

                delim = "" if max(pad_list) <= 9 else "x"
                name = "pad{}".format(delim.join([str(x) for x in pad_list]))

                args_dict = {
                    "pad": np.array(paddings),
                    "pad_const_int": pad_const_int,
                    "pad_const_fp": pad_const_fp,
                }
                arg_list.append((name, args_dict))

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            logger.debug(
                f"agPad: No PadSmallerZero ErrorIf test created for input shape: {shapeList[0]}"
            )

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            shapeList,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2685
2686 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002687 def agPooling(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002688 arg_list = []
2689
2690 shape = shapeList[0]
2691 if error_name != ErrorIf.WrongRank:
2692 assert len(shape) == 4
2693
Jeremy Johnson0c716862023-04-13 17:18:19 +01002694 test_level8k = testGen.args.level8k and error_name is None
2695
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002696 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002697 startKernel = 2
2698 startPad = 0
2699 if not test_level8k:
2700 # Generate comprehensive argument lists
2701 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2702 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2703 # Stride must be greater than 1 to force non-integer error
2704 s_vals = [
2705 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2706 ]
2707 strides = {x for x in itertools.product(*([s_vals] * 2))}
2708 k_vals = [
2709 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2710 ]
2711 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2712 max_dim_size = None
2713 else:
2714 # Only test 8k levels
2715 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2716 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2717 strides = {(1, bigStride), (bigStride, 4)}
2718 kernels = {(1, bigKernel), (bigKernel, 3)}
2719 paddings = set()
2720 for s in sorted(list(strides)):
2721 for k in sorted(list(kernels)):
2722 padding = []
2723 for idx in range(len(k)):
2724 total_padding = s[idx] - shape[idx + 1] + k[idx]
2725 while total_padding < 0:
2726 # Must meet: shape + padding > kernel
2727 total_padding += s[idx]
2728 if total_padding < k[idx]:
2729 padding.extend([0, total_padding])
2730 else:
2731 # Note this may produce padding >= k[idx] which is not
2732 # allowed - but will be ignored in the creation loop below
2733 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2734 paddings.add(tuple(padding))
2735 # Create a limit for the output dimensions size
2736 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002737
James Ward8b390432022-08-12 20:48:56 +01002738 if opName == "max_pool2d":
2739 accum_dtypes = [None] # max_pool has no accumulate dtype
2740 elif dtype == DType.INT8 or dtype == DType.INT16:
2741 accum_dtypes = [DType.INT32]
2742 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002743 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002744 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002745 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002746 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2747 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002748 elif error_name is None:
2749 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2750 else:
2751 # Set to something for the ErrorIf case which has
2752 # incorrect input data-type
2753 accum_dtypes = [DType.INT32]
2754
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00002755 if error_name == ErrorIf.WrongAccumulatorType:
2756 accum_dtypes = list(gtu.usableDTypes(excludes=accum_dtypes))
2757
Jeremy Johnson0c716862023-04-13 17:18:19 +01002758 if not test_level8k:
2759 if testGen.args.oversize:
2760 # add some oversize argument values
2761 bigStride = 7
2762 bigKernel = 9
2763 strides.update(
2764 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002765 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002766 kernels.update(
2767 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2768 )
2769 if max(shape) < 64:
2770 # padding must be less than the kernel size
2771 bigPadding = bigKernel - 1
2772 paddings.update(
2773 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2774 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002775
Jeremy Johnson87460262024-03-25 09:46:02 +00002776 if error_name:
2777 # Cycle through all error_if tests but we only keep the first few
2778 sparsity = 1
2779 else:
2780 # There are too many parameter combinations, so generate them sparsely
2781 sparsity_factor = 500
2782 sparsity = (
2783 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2784 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002785 else:
2786 # We have already limited test output combinations for 8k tests
2787 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002788
James Ward8b390432022-08-12 20:48:56 +01002789 arg_str = (
2790 "acc{}_st{}_kern{}_pad{}"
2791 if accum_dtypes[0] is not None
2792 else "st{}_kern{}_pad{}"
2793 )
2794
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002795 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01002796 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002797 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01002798
2799 # Support for larger values than 9 needs different delimiter
2800 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01002801 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01002802 delim.join([str(x) for x in stride]),
2803 delim.join([str(x) for x in kern]),
2804 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01002805 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002806 args_dict = {
2807 "stride": stride,
2808 "pad": pad,
2809 "kernel": kern,
2810 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002811 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002812 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
2813 }
James Ward8b390432022-08-12 20:48:56 +01002814
2815 if accum is not None:
2816 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002817 args_dict["acc_type"] = accum
2818 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002819
Jeremy Johnson87460262024-03-25 09:46:02 +00002820 more_tests = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002821 n = 0
James Ward8b390432022-08-12 20:48:56 +01002822 for a in accum_dtypes:
2823 for s in sorted(list(strides)):
2824 for p in sorted(list(paddings)):
2825 for k in sorted(list(kernels)):
2826 if error_name in [
2827 ErrorIf.StrideSmallerOne,
2828 ErrorIf.KernelSmallerOne,
2829 ErrorIf.PadSmallerZero,
2830 ErrorIf.PadLargerEqualKernel,
2831 ]:
2832 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002833 rng, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002834 )
James Ward8b390432022-08-12 20:48:56 +01002835 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002836 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002837 get_arg_list_element(a, sNew, pNew, kNew, shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002838 )
James Ward8b390432022-08-12 20:48:56 +01002839 elif (
Jeremy Johnson87460262024-03-25 09:46:02 +00002840 more_tests
2841 and n % sparsity == 0
James Ward8b390432022-08-12 20:48:56 +01002842 # padding must not exceed the kernel size
2843 and p[0] < k[0]
2844 and p[1] < k[0]
2845 and p[2] < k[1]
2846 and p[3] < k[1]
2847 # the padded shape must exceed the kernel size
2848 and (shape[1] + p[0] + p[1]) > k[0]
2849 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002850 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002851 partial_h = shape[1] + p[0] + p[1] - k[0]
2852 partial_w = shape[2] + p[2] + p[3] - k[1]
2853 remainder_h = partial_h % s[0]
2854 remainder_w = partial_w % s[1]
2855 output_h = partial_h // s[0] + 1
2856 output_w = partial_w // s[1] + 1
Jeremy Johnsonaf090182024-02-13 18:25:39 +00002857 logger.debug(
2858 f"agPooling: {shape} remainder=({remainder_h}, {remainder_w}) output=({output_h}, {output_w})"
2859 )
James Ward8b390432022-08-12 20:48:56 +01002860 if (
2861 # the parameters must produce integer exact output
2862 error_name != ErrorIf.PoolingOutputShapeNonInteger
2863 and remainder_h == 0
2864 and remainder_w == 0
2865 ) or (
2866 error_name == ErrorIf.PoolingOutputShapeNonInteger
2867 and (remainder_h != 0 or remainder_w != 0)
2868 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002869 if (
2870 max_dim_size is not None
2871 and max(output_h, output_w) > max_dim_size
2872 ):
2873 # Test will consume too much memory - skip it
2874 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002875 # Dot products = N*OH*OW*C
2876 dp = gtu.product(
2877 (shape[0], output_h, output_w, shape[3])
2878 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002879 arg_list.append(
2880 get_arg_list_element(a, s, p, k, dp, shape)
2881 )
Jeremy Johnson87460262024-03-25 09:46:02 +00002882 if (
2883 error_name
2884 and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS
2885 ):
2886 # Found enough errors
2887 logger.debug(
2888 f"Skipping creating more pooling error tests for {error_name}"
2889 )
2890 more_tests = False
2891
James Ward8b390432022-08-12 20:48:56 +01002892 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002893
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002894 # Now add data generator types
2895 arg_list = TosaArgGen._add_data_generators(
2896 testGen,
2897 opName,
evacha019c96eef2024-02-07 11:21:55 +00002898 shapeList,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002899 dtype,
2900 arg_list,
2901 error_name,
2902 )
2903
2904 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002905 return arg_list
2906
2907 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002908 def agCast(testGen, rng, opName, shapeList, inDtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002909 arg_list = []
2910
2911 # Enumerate the output types here
2912 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002913 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(inDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002914 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00002915 dtypeList = [
2916 DType.BOOL,
2917 DType.INT16,
2918 DType.INT32,
2919 DType.FP16,
2920 DType.BF16,
2921 DType.FP32,
2922 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002923 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00002924 dtypeList = [
2925 DType.BOOL,
2926 DType.INT8,
2927 DType.INT32,
2928 DType.FP16,
2929 DType.BF16,
2930 DType.FP32,
2931 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002932 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00002933 dtypeList = [
2934 DType.BOOL,
2935 DType.INT8,
2936 DType.INT16,
2937 DType.FP16,
2938 DType.BF16,
2939 DType.FP32,
2940 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002941 elif inDtype == DType.BOOL:
2942 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01002943 elif inDtype == DType.FP16:
Won Jeon2c34b462024-02-06 18:37:00 +00002944 dtypeList = [
2945 DType.INT8,
2946 DType.INT16,
2947 DType.INT32,
2948 DType.FP32,
2949 DType.FP8E4M3,
2950 DType.FP8E5M2,
2951 ]
James Ward24dbc422022-10-19 12:20:31 +01002952 elif inDtype == DType.BF16:
Won Jeon2c34b462024-02-06 18:37:00 +00002953 dtypeList = [
2954 DType.INT8,
2955 DType.INT16,
2956 DType.INT32,
2957 DType.FP32,
2958 DType.FP8E4M3,
2959 DType.FP8E5M2,
2960 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002961 elif inDtype == DType.FP32:
Won Jeon2c34b462024-02-06 18:37:00 +00002962 dtypeList = [
2963 DType.INT8,
2964 DType.INT16,
2965 DType.INT32,
2966 DType.FP16,
2967 DType.BF16,
2968 DType.FP8E4M3,
2969 DType.FP8E5M2,
2970 ]
2971 elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]:
2972 dtypeList = [DType.FP16, DType.BF16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002973 elif error_name == ErrorIf.WrongInputType:
2974 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002975 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002976 else:
2977 raise Exception("Unexpected input dtype: {}".format(inDtype))
2978
2979 for dtype in dtypeList:
Jeremy Johnson708da822023-11-15 16:25:45 +00002980 arg_list.append(
2981 ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
2982 )
2983
2984 # Now add data generator types
2985 arg_list = TosaArgGen._add_data_generators(
2986 testGen,
2987 opName,
evacha019c96eef2024-02-07 11:21:55 +00002988 shapeList,
Jeremy Johnson708da822023-11-15 16:25:45 +00002989 dtype,
2990 arg_list,
2991 error_name,
2992 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002993
2994 return arg_list
2995
    @staticmethod
    def agRescale(testGen, rng, opName, shapeList, inDtype, error_name=None):
        """Generate argument combinations for RESCALE tests.

        Enumerates output dtype x scale32 x double_round x per_channel
        combinations, skipping combinations that are invalid for the given
        input dtype or that do not apply to the requested error_name test.
        For each kept combination, random per-channel scales are generated
        and converted to integer multiplier/shift pairs.

        Returns a list of tuples: (arg_str, args_dict).
        """
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            # Skip output types that cannot express the requested ErrorIf
            # condition, or are invalid pairings with the input dtype
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    # Per_channel is only valid with rank > 0
                    pc_options = (False, True) if len(shapeList[0]) > 0 else (False,)
                    for per_channel in pc_options:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition. Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        # Number of channels: last dim for per-channel mode,
                        # otherwise a single broadcast scale
                        if per_channel:
                            nc = shapeList[0][-1]
                        else:
                            nc = 1

                        in_type_width = gtu.dtypeWidth(inDtype)
                        out_type_width = gtu.dtypeWidth(outDtype)

                        # Calculate scale based on:
                        # scale = a *(2^output_width)/(2^input_width))

                        a = np.float32(rng.random(size=[nc]))
                        scale_arr = a * np.float32(
                            (1 << out_type_width) / (1 << in_type_width)
                        )

                        if scale32:
                            # Cap the scaling at 2^31 - 1 for scale32
                            scale_arr = np.clip(
                                scale_arr, 1.0 / (1 << 31), (1 << 31) - 1
                            )
                        else:
                            # Cap the scaling at 2^15 - 1 for scale16
                            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

                        logger.debug(
                            f"agRescale: {out_type_width} {in_type_width} -> {scale_arr}"
                        )

                        # Convert each float scale to a fixed-point
                        # multiplier/shift pair for the serialized op
                        multiplier_arr = np.int32(np.zeros(shape=[nc]))
                        shift_arr = np.int32(np.zeros(shape=[nc]))
                        for i in range(nc):
                            (
                                multiplier_arr[i],
                                shift_arr[i],
                            ) = TosaQuantGen.computeMultiplierAndShift(
                                scale_arr[i], scale32
                            )

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                {
                                    "output_dtype": outDtype,
                                    "scale": scale32,
                                    "double_round": double_round,
                                    "per_channel": per_channel,
                                    "multiplier": multiplier_arr,
                                    "shift": shift_arr,
                                },
                            )
                        )

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            shapeList,
            inDtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
3151
3152 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003153 def agMul(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003154 arg_list = []
3155
3156 if dtype is DType.INT32:
3157 for p in range(testGen.args.num_rand_permutations):
3158
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003159 shift = rng.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003160 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003161 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003162 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003163
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003164 arg_list = TosaArgGen._add_data_generators(
3165 testGen,
3166 opName,
evacha019c96eef2024-02-07 11:21:55 +00003167 shapeList,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003168 dtype,
3169 arg_list,
3170 error_name,
3171 )
3172 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003173 return arg_list
3174
3175 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003176 def agArithmeticRightShift(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003177 arg_list = []
3178
Jeremy Johnson587cc842024-02-08 11:45:44 +00003179 for round in (True, False):
3180 args_dict = {
3181 "round": round,
3182 }
3183 arg_list.append((f"round{round}", args_dict))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003184
Jeremy Johnson587cc842024-02-08 11:45:44 +00003185 arg_list = TosaArgGen._add_data_generators(
3186 testGen,
3187 opName,
evacha019c96eef2024-02-07 11:21:55 +00003188 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003189 dtype,
3190 arg_list,
3191 error_name,
3192 )
3193 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003194 return arg_list
3195
Luke Hutton57287132023-02-06 14:54:18 +00003196 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003197 def agFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Luke Hutton57287132023-02-06 14:54:18 +00003198 arg_list = []
3199
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003200 shape = shapeList[0]
3201 dot_products = gtu.product(shape)
3202 ks = 2 * shape[1] * shape[2] # 2*H*W
3203 for inverse in (True, False):
3204 args_dict = {
3205 "dot_products": dot_products,
3206 "shape": shape,
3207 "ks": ks,
3208 "acc_type": dtype,
3209 "inverse": inverse,
3210 }
3211 arg_list.append((f"inverse{inverse}", args_dict))
Luke Hutton57287132023-02-06 14:54:18 +00003212
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003213 arg_list = TosaArgGen._add_data_generators(
3214 testGen,
3215 opName,
evacha019c96eef2024-02-07 11:21:55 +00003216 shapeList,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003217 dtype,
3218 arg_list,
3219 error_name,
3220 )
3221 # Return list of tuples: (arg_str, args_dict)
Luke Hutton57287132023-02-06 14:54:18 +00003222 return arg_list
3223
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003224 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003225 def agRFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003226 arg_list = []
3227
3228 shape = shapeList[0]
3229 dot_products = gtu.product(shape)
3230 ks = shape[1] * shape[2] # H*W
3231 args_dict = {
3232 "dot_products": dot_products,
3233 "shape": shape,
3234 "ks": ks,
3235 "acc_type": dtype,
3236 }
3237 arg_list.append(("", args_dict))
3238
3239 arg_list = TosaArgGen._add_data_generators(
3240 testGen,
3241 opName,
evacha019c96eef2024-02-07 11:21:55 +00003242 shapeList,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003243 dtype,
3244 arg_list,
3245 error_name,
3246 )
3247 # Return list of tuples: (arg_str, args_dict)
3248 return arg_list
3249
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003250 # Helper function for reshape. Gets some factors of a larger number.
3251 @staticmethod
3252 def getFactors(val, start=1):
3253 factors = []
3254
3255 for i in range(start, int(np.sqrt(val)) + 1):
3256 if (val % i) == 0:
3257 factors.append(i)
3258
3259 return factors
3260
3261 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003262 def agReshape(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003263 arg_list = []
3264
3265 origShape = shapeList[0]
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003266 totalElements = gtu.product(origShape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003267 factors = TosaArgGen.getFactors(totalElements)
3268
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003269 # Find new shapes up to the number of permutations asked for
3270 # This code is NOT fast. Fortunately, the numbers are fairly small.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003271 for p in range(testGen.args.num_rand_permutations):
Jeremy Johnsondd975b82024-02-28 17:29:13 +00003272 # Rank from 1 to MAX_TENSOR_RANK
3273 newRank = rng.randInt(1, (gtu.MAX_TENSOR_RANK + 1))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003274 if len(factors) < newRank:
3275 continue
3276
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003277 # escape_counter limits the generation of new shapes to a reasonable time
3278 for escape_counter in range(100):
3279
3280 # Generate the new shape of the chosen new rank
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003281 newShape = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003282 remainingElements = totalElements
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003283 shuffledFactors = rng.permutation(factors)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003284 for i in range(1, newRank):
3285 # pick rank-1 factors
3286 newShape.append(shuffledFactors[0])
3287 remainingElements = remainingElements // shuffledFactors[0]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003288 shuffledFactors = rng.permutation(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003289 TosaArgGen.getFactors(remainingElements)
3290 )
3291 newShape.append(remainingElements)
3292
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003293 # Check for duplicates
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003294 duplicate = False
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003295 for name, args_dict in arg_list:
3296 if args_dict["new_shape"] == newShape:
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003297 duplicate = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003298 break
3299
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003300 if not duplicate:
3301 outShape = "x".join([str(x) for x in newShape])
3302 arg_list.append(
3303 (
3304 "perm{}_rank{}_out{}".format(p, newRank, outShape),
3305 {"new_shape": newShape},
3306 )
3307 )
3308 # Found an output shape for this permutation
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003309 break
3310
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003311 # Now add data generator types
3312 arg_list = TosaArgGen._add_data_generators(
3313 testGen,
3314 opName,
evacha019c96eef2024-02-07 11:21:55 +00003315 shapeList,
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003316 dtype,
3317 arg_list,
3318 error_name,
3319 )
3320
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003321 return arg_list
3322
3323 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003324 def agTranspose(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003325 arg_list = []
3326
3327 ifm_shape = shapeList[0]
3328
3329 if error_name == ErrorIf.IndexOutsideBounds:
3330 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
3331 incorrect_small_index = range(-len(ifm_shape), 0)
3332 permutations = [p for p in itertools.permutations(incorrect_large_index)]
3333 permutations.extend(
3334 [p for p in itertools.permutations(incorrect_small_index)]
3335 )
3336 elif error_name == ErrorIf.IndexUsedTwice:
3337 # Create list with a duplicated index
3338 perm_range = list(range(len(ifm_shape)))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003339 index_choice = rng.choice(range(len(perm_range)))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003340 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
3341 permutations = [p for p in itertools.permutations(perm_range)]
3342
3343 else:
3344 # Get all permutations
3345 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
3346
3347 # Limit to possible permutations from shape dimension or argument setting
3348 limit = min(len(permutations), testGen.args.num_rand_permutations)
3349
3350 # Get random permutation generator that uses all permutations
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003351 random_permutations = rng.permutation(permutations)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003352
3353 # Create list of required amount of permutations
3354 arg_list = [
evacha0198477222024-01-26 12:25:32 +00003355 ("perm{}".format(p), {"perms": random_permutations[p].tolist()})
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003356 for p in range(limit)
3357 ]
evacha0198477222024-01-26 12:25:32 +00003358 # Now add data generator types
3359 arg_list = TosaArgGen._add_data_generators(
3360 testGen,
3361 opName,
evacha019c96eef2024-02-07 11:21:55 +00003362 shapeList,
evacha0198477222024-01-26 12:25:32 +00003363 dtype,
3364 arg_list,
3365 error_name,
3366 )
3367 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003368 return arg_list
3369
3370 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003371 def agSlice(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003372 arg_list = []
3373
3374 ifm_shape = shapeList[0]
3375 rank = len(ifm_shape)
3376
3377 for p in range(testGen.args.num_rand_permutations):
3378 start = []
3379 size = []
3380
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003381 for i in range(rank):
3382 if ifm_shape[i] > 1:
Jeremy Johnson3f3de012024-04-08 15:18:05 +01003383 # Start from 0 to dimension size - 1 to leave room for slice of 1
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003384 start.append(rng.randInt(0, ifm_shape[i]))
Jeremy Johnson3f3de012024-04-08 15:18:05 +01003385 # Size from 1 up to rest of room (dimension size - start)
3386 size.append(rng.randInt(1, ifm_shape[i] + 1 - start[i]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003387
Jeremy Johnson3f3de012024-04-08 15:18:05 +01003388 # Should never hit an invalid slice size
3389 assert size[i] > 0 and (size[i] + start[i]) <= ifm_shape[i]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003390 else:
3391 start.append(0)
3392 size.append(1)
3393
Jeremy Johnson3f3de012024-04-08 15:18:05 +01003394 # If ERROR_IF test required then incorrect start, size will be returned
3395 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
3396 rng, error_name, ifm_shape, start, size
3397 )
3398 arg_list.append(("perm{}".format(p), {"start": start, "size": size}))
3399
evacha017f7d4252024-01-24 12:08:09 +00003400 # Now add data generator types
3401 arg_list = TosaArgGen._add_data_generators(
3402 testGen,
3403 opName,
evacha019c96eef2024-02-07 11:21:55 +00003404 shapeList,
evacha017f7d4252024-01-24 12:08:09 +00003405 dtype,
3406 arg_list,
3407 error_name,
3408 )
3409 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003410 return arg_list
3411
3412 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003413 def agTile(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003414 arg_list = []
3415
3416 ifm_shape = shapeList[0]
3417 rank = len(ifm_shape)
3418
3419 for p in range(testGen.args.num_rand_permutations):
3420
3421 # Pick a few random, but small multiple values
3422 # because otherwise this has a tendency to generate
3423 # enormous tensors
3424 multiples = []
3425 for i in range(rank):
3426 if ifm_shape[i] > 1000:
3427 # Multiple of 1 if ifm_shape dimension is large to reduce
3428 # tensor size
3429 multiples.append(1)
3430 elif max(ifm_shape) > 1000:
3431 multiples.append(2)
3432 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003433 multiples.append(rng.randInt(1, 4))
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003434 arg_list.append(("perm{}".format(p), {"multiples": multiples}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003435
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003436 # Now add data generator types
3437 arg_list = TosaArgGen._add_data_generators(
3438 testGen,
3439 opName,
evacha019c96eef2024-02-07 11:21:55 +00003440 shapeList,
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003441 dtype,
3442 arg_list,
3443 error_name,
3444 )
3445 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003446 return arg_list
3447
3448 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003449 def agResize(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003450 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003451 ifm_shape = shapeList[0]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003452
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003453 def get_aspect_ratio_resize_params():
3454 common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003455 aspect_ratio = rng.choice(common_aspect_ratios)
3456 invert = rng.choice((False, True))
3457 letterbox = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003458
3459 scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
3460 scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
3461 scale_y_d = scale_x_d = 1
3462 offset_x = offset_y = 0
3463
3464 if letterbox:
3465 max_border = scale_y_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003466 border_y = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003467 border_x = 0
3468 else:
3469 # Pillarboxing
3470 border_y = 0
3471 max_border = scale_x_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003472 border_x = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003473
3474 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3475 offset = (offset_y, offset_x)
3476 border = (border_y, border_x)
3477
3478 return scale, offset, border
3479
3480 def get_upscale_downscale_params():
3481 valid_params = False
3482 while not valid_params:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003483 upscale = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003484
3485 # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003486 origin_sampling = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003487
3488 if upscale:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003489 shift = rng.randInt(low=1, high=4)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003490 scale_x_d = scale_y_d = 1
3491 scale_x_n = scale_y_n = (
3492 1 << shift if origin_sampling else 2 << shift
3493 )
3494 border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
3495 offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
3496 else:
3497 scale_x_n = 1
3498 scale_y_n = 1
3499
3500 # Return list of valid scale_*_d values (max value 4) given input dim shape
3501 def get_valid_denom(ifm_dim):
3502 return [x for x in range(1, 5) if ifm_dim % x == 1]
3503
3504 # Generate list of valid downscale values and choose one randomly
3505 valid_scale_y_ds = get_valid_denom(ifm_shape[1])
3506 valid_scale_x_ds = get_valid_denom(ifm_shape[2])
3507
3508 if not valid_scale_y_ds and not valid_scale_x_ds:
3509 # Bad parameters, skip
3510 continue
3511
3512 if not valid_scale_y_ds:
3513 scale_y_d = 1
3514 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003515 scale_y_d = rng.choice(valid_scale_y_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003516
3517 if not valid_scale_x_ds:
3518 scale_x_d = 1
3519 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003520 scale_x_d = rng.choice(valid_scale_x_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003521
3522 border_x = border_y = 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003523 offset_y = rng.randInt(0, 16 * scale_y_n)
3524 offset_x = rng.randInt(0, 16 * scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003525 valid_params = True
3526
3527 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3528 offset = (offset_y, offset_x)
3529 border = (border_y, border_x)
3530 return scale, offset, border
3531
        def get_rand_params():
            """Return a fully random (scale, offset, border) parameter set.

            Scale numerators/denominators, offsets and borders are drawn at
            random; the denominators are then enlarged if needed so the overall
            scale stays within testGen.TOSA_8K_LEVEL_MAX_SCALE.

            NOTE: the order of the rng.randInt calls below is part of the
            deterministic test-generation sequence - do not reorder them.
            """

            def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
                # Grow the denominator just enough that scale_n/scale_d <= max_scale.
                scale = scale_n / scale_d
                if scale > max_scale:
                    factor = scale / max_scale
                    new_scale_d = math.ceil(scale_d * factor)
                    assert scale_n / new_scale_d <= max_scale
                    scale_d = new_scale_d
                return scale_d

            # Scale
            scale_y_n = rng.randInt(low=1, high=(1 << 11))
            scale_x_n = rng.randInt(low=1, high=(1 << 11))

            scale_y_d = rng.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = rng.randInt(low=1, high=(16 * scale_x_n))

            # Clamp both axes' scales to the 8K level maximum
            scale_y_d = fix_scale_to_max_scale(
                scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_x_d = fix_scale_to_max_scale(
                scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )

            # Offsets and border within the scale
            offset_y = rng.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = rng.randInt(low=-scale_x_n, high=(16 * scale_x_n))
            border_y = rng.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = rng.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border
3566
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003567 def get_level_8k_params():
3568 # Create 64x scale - 64/1 to 2048/32
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003569 scale_d = rng.randInt(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003570 low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
3571 )
3572 scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
3573 # Create half to fifth scaling
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003574 scale_d_alt = rng.randInt(low=2, high=6)
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003575 scale_n_alt = 1
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003576 switch = rng.choice((False, True))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003577 if switch:
3578 scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
3579 else:
3580 scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)
3581
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003582 offset_y = rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
3583 offset_x = rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003584 offset = (offset_y, offset_x)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003585 border_y = rng.choice((-16 * scale[0], 0, scale[0] - 1))
3586 border_x = rng.choice((-16 * scale[2], 0, scale[2] - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003587 border = (border_y, border_x)
3588 return scale, offset, border
3589
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003590 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003591 # Exclude illegal {mode, type} configurations. Pick legal output types
3592 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
3593 outputDTypeList = [DType.INT8]
3594 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
3595 outputDTypeList = [DType.INT16]
3596 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
3597 outputDTypeList = [DType.INT32]
3598 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
3599 outputDTypeList = [DType.INT48]
James Ward8b390432022-08-12 20:48:56 +01003600 elif dtype == DType.FP16:
3601 outputDTypeList = [DType.FP16]
James Ward24dbc422022-10-19 12:20:31 +01003602 elif dtype == DType.BF16:
3603 outputDTypeList = [DType.BF16]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003604 elif dtype == DType.FP32:
3605 outputDTypeList = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00003606 elif dtype == DType.FP8E4M3:
3607 outputDTypeList = [DType.FP8E4M3]
3608 elif dtype == DType.FP8E5M2:
3609 outputDTypeList = [DType.FP8E5M2]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003610 elif error_name == ErrorIf.WrongInputType:
3611 # If an incorrect input type is used then we set a 'correct'
3612 # output type to avoid other errors
3613 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
3614 else:
3615 continue
3616
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003617 arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"
3618
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003619 for outputDType in outputDTypeList:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003620 perm = 0
3621 while perm < testGen.args.num_rand_permutations:
3622 # Random choice of type of params we are testing
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003623 if not testGen.args.level8k:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003624 _rnd_param_fn = rng.choice(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003625 (
3626 get_rand_params,
3627 get_upscale_downscale_params,
3628 get_aspect_ratio_resize_params,
3629 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003630 )
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003631 scale, offset, border = _rnd_param_fn()
3632 else:
3633 scale, offset, border = get_level_8k_params()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003634
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003635 # Expand params for bounds-checking
3636 (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
3637 (offset_y, offset_x) = offset
3638 (border_y, border_x) = border
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003639
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003640 # Make sure output dimensions OH and OW are integers
3641 partial_output_y = (
3642 (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
3643 )
3644 partial_output_x = (
3645 (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
3646 )
3647 if error_name == ErrorIf.ResizeOutputShapeNonInteger:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003648 # Look for non-integer test
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003649 if (
3650 partial_output_y % scale_y_d == 0
3651 and partial_output_x % scale_x_d == 0
3652 ):
3653 # Skip this test as it doesn't produce NonInteger output
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003654 if perm > 0:
3655 perm += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003656 continue
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003657 else:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003658 # Alter the scaling factors to make the output integer
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003659 while partial_output_y % scale_y_d != 0:
3660 scale_y_d -= 1
3661 while partial_output_x % scale_x_d != 0:
3662 scale_x_d -= 1
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003663 # Make sure we are still within max scaling
3664 if (
3665 scale_y_n / scale_y_d
3666 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
3667 scale_x_n / scale_x_d
3668 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
3669 # Skip the test as it is using too large a scaling factor
3670 if perm > 0:
3671 perm += 1
3672 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003673
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003674 output_y = partial_output_y // scale_y_d + 1
3675 output_x = partial_output_x // scale_x_d + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003676
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003677 if (
3678 output_y >= testGen.args.max_resize_output_dim
3679 or output_x >= testGen.args.max_resize_output_dim
3680 ) and error_name is None:
3681 # Skip positive test if output dim will be too high
3682 # Avoid high test latency and OOM issues
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003683 if not testGen.args.level8k or perm > 0:
3684 perm += 1
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003685 continue
3686
3687 if (
3688 output_y <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003689 or output_y >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003690 or output_x <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003691 or output_x >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003692 ):
3693 # Output dimensions out of scope
3694 if error_name is not None and perm > 0:
3695 # As long as we have one ERROR_IF test, don't worry
3696 # about creating all the other permutations
3697 perm += 1
3698 continue
3699
3700 if error_name == ErrorIf.ResizeOutputShapeMismatch and (
3701 (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003702 output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003703 and output_y - scale_y_d < 1
3704 )
3705 or (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003706 output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003707 and output_x - scale_x_d < 1
3708 )
3709 ):
3710 # Can't create a negative test with these params as it
3711 # will create invalid output size
3712 if perm > 0:
3713 perm += 1
3714 continue
3715
3716 scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
3717 offset = [offset_y, offset_x]
3718 border = [border_y, border_x]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003719
3720 # Common for all data types
3721 if error_name is not None:
3722 (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003723 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003724 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003725 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003726 outputDTypeNew,
3727 ) = TosaErrorIfArgGen.eiResizeErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003728 rng,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003729 error_name,
3730 mode,
3731 dtype,
3732 shapeList,
3733 outputDType,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003734 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003735 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003736 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003737 )
3738 else:
3739 outputDTypeNew = outputDType
3740
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003741 arg_to_append = (
3742 arg_str.format(
3743 "N" if mode == ResizeMode.NEAREST else "B",
3744 testGen.typeStr(outputDTypeNew),
3745 scale[0],
3746 scale[1],
3747 scale[2],
3748 scale[3],
3749 offset[0],
3750 offset[1],
3751 border[0],
3752 border[1],
3753 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003754 {
3755 "mode": mode,
3756 "scale": scale,
3757 "offset": offset,
3758 "border": border,
3759 "output_dtype": outputDTypeNew,
3760 },
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003761 )
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003762 if arg_to_append in arg_list:
3763 # Skip already generated test params
3764 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003765
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003766 # Valid permutation
3767 perm += 1
3768 arg_list.append(arg_to_append)
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003769
3770 # Now add data generator types
3771 arg_list = TosaArgGen._add_data_generators(
3772 testGen,
3773 opName,
evacha019c96eef2024-02-07 11:21:55 +00003774 shapeList,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003775 dtype,
3776 arg_list,
3777 error_name,
3778 )
3779 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003780 return arg_list
3781
3782 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003783 def agTable(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003784 arg_list = []
3785
3786 if dtype == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003787 table = np.int32(rng.integers(low=-128, high=128, size=[256])).tolist()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003788 else: # INT16
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003789 table = np.int32(rng.integers(low=-32768, high=32768, size=[513])).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003790 # Make sure all slopes are within REQUIRE min/max 16-bit int
3791 for idx in range(len(table) - 1):
3792 slope = table[idx + 1] - table[idx]
3793 # Alter the next table entry to force the slope to be ok
3794 if slope > 32767:
3795 table[idx + 1] -= slope - 32767
3796 if slope < -32768:
3797 table[idx + 1] -= slope + 32768
3798 slope = table[idx + 1] - table[idx]
3799 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003800 arg_list.append(
3801 (
3802 "",
Jeremy Johnson587cc842024-02-08 11:45:44 +00003803 {"table": table},
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003804 )
3805 )
Jeremy Johnson587cc842024-02-08 11:45:44 +00003806 # Now add data generator types
3807 arg_list = TosaArgGen._add_data_generators(
3808 testGen,
3809 opName,
evacha019c96eef2024-02-07 11:21:55 +00003810 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003811 dtype,
3812 arg_list,
3813 error_name,
3814 )
3815 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003816 return arg_list
3817
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003818 def agCondIf(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003819 # CondIf generates the condition values here.
3820 # Convert to tensors in the build function, along with the
3821 # then and else blocks
3822 arg_list = []
3823
3824 for c in [False, True]:
Jeremy Johnson587cc842024-02-08 11:45:44 +00003825 arg_list.append(("cond{}".format(int(c)), {"condition": c}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003826
Jeremy Johnson587cc842024-02-08 11:45:44 +00003827 # Now add data generator types
3828 arg_list = TosaArgGen._add_data_generators(
3829 testGen,
3830 opName,
evacha019c96eef2024-02-07 11:21:55 +00003831 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003832 dtype,
3833 arg_list,
3834 error_name,
3835 )
3836 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003837 return arg_list
3838
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003839 def agWhileLoop(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003840 # While loop: 0 iterations, 1, more than 1
3841 arg_list = []
3842
Jeremy Johnson587cc842024-02-08 11:45:44 +00003843 for iterations in [0, 1, 4]:
3844 arg_list.append(("iter{}".format(iterations), {"iterations": iterations}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003845
Jeremy Johnson587cc842024-02-08 11:45:44 +00003846 # Now add data generator types
3847 arg_list = TosaArgGen._add_data_generators(
3848 testGen,
3849 opName,
evacha019c96eef2024-02-07 11:21:55 +00003850 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003851 dtype,
3852 arg_list,
3853 error_name,
3854 )
3855 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003856 return arg_list