Jeremy Johnsonbd801962024-01-03 17:07:44 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
Jeremy Johnsonaf090182024-02-13 18:25:39 +00004import logging
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01005import math
6
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
16 # DTypeNames, DType, Op and ResizeMode are convenience aliases for the
17 # flatc-generated types, which should be enums but aren't
18
Jeremy Johnsonaf090182024-02-13 18:25:39 +000019logging.basicConfig()
20logger = logging.getLogger("tosa_verif_build_tests")
21
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010022
23class TosaQuantGen:
24 """QuantizedInfo random generator helper functions.
25
26 Specify with 'qgen': in the operator definition.
27 """
28
29 def __init__(self):
30 pass
31
32 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010033 def getZeroPoint(rng, zeropoint, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010034
35 if dtype == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010036 if zeropoint is not None:
37 return min(127, max(-128, zeropoint))
38 return rng.randInt(-128, 128)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010039 elif dtype == DType.UINT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010040 if zeropoint is not None:
41 return min(255, max(0, zeropoint))
42 return rng.randInt(0, 256)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010043 elif error_name in [
44 ErrorIf.InputZeroPointNotZero,
45 ErrorIf.WeightZeroPointNotZero,
46 ErrorIf.OutputZeroPointNotZero,
47 ]:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010048 zero_point = rng.randInt(-128, 128)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010049 if zero_point == 0:
50 zero_point = 1
51 return zero_point
52 return 0
53
54 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010055 def qgUnary(rng, zeropoint, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010056 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000057 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010058 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
59 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000060 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010061 elif error_name == ErrorIf.OutputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000062 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010063 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
64 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000065 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010066 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000067 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010068 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
69 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000070 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010071 return qinfo
72
73 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010074 def qgConv(rng, zeropoint, op, dtype_or_dtypeList, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010075 if isinstance(dtype_or_dtypeList, list):
76 # a list of [input, weights, accumulator] dtypes
77 dtypeList = dtype_or_dtypeList
78 else:
79 # an int, [input, weights, accumulator] dtypes are the same
80 dtypeList = [dtype_or_dtypeList] * 3
81
82 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000083 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010084 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0], error_name),
85 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1]),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000086 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010087 elif error_name == ErrorIf.WeightZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000088 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010089 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0]),
90 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1], error_name),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000091 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010092 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000093 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010094 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0]),
95 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1]),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000096 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010097 return qinfo
98
99 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100100 def qgMatmul(rng, zeropoint, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100101 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000102 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100103 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
104 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000105 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100106 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000107 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100108 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
109 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000110 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100111 return qinfo
112
113 @staticmethod
114 def computeMultiplierAndShift(scaleFp, scale32):
115 # Derived from computeMultiplierAndShiftTosaScale32
116 # Provide a floating-point scaling factor and the scale32 parameter
117 # to compute the multiplier and shift
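        # Worked example (scale32=True): scaleFp=0.8 -> frexp gives m=0.8, exp=0,
        # multiplier=round(0.8 * 2**31)=1717986918 and shift=31, so that
        # multiplier / 2**shift ~= 0.8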
118
119 if scale32:
120 scaleBits = 31
121 else:
122 scaleBits = 15
123
124 m, shift = math.frexp(scaleFp)
125
126 if scaleFp < 0.0:
127 m = -m
128
129 multiplier = round(m * (1 << scaleBits))
130 assert multiplier <= (1 << scaleBits)
131
132 if multiplier == (1 << scaleBits):
133 multiplier = multiplier // 2
134 shift = shift + 1
135
136 shift = (-shift) + scaleBits
Jeremy Johnsonaf090182024-02-13 18:25:39 +0000137 logger.debug(
138 f"computeMultiplierAndShift: scalefp={scaleFp} scaleBits={scaleBits} m={m} mult={multiplier} shift={shift}"
139 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100140
141 # Adjust multiplier such that shift is in allowed value range.
142 if shift == 0:
143 multiplier = multiplier // 4
144 shift = shift + 2
145 elif shift == 1:
146 multiplier = multiplier // 2
147 shift = shift + 1
148 elif shift == 63:
149 multiplier = multiplier * 2
150 shift = shift - 1
151
152 assert multiplier <= (1 << scaleBits)
153 assert shift >= 2 and shift <= 62
154
155 return multiplier, shift
156
157
158class TosaTensorGen:
159 """Tensor generators create a shape list for the placeholder and const tensor
160 data operands for the operator.
161
162 The actual random data is generated separately for each test.
163 """
164
165 def __init__(self):
166 pass
167
168 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100169 def tgBasic(testGen, rng, op, rank, error_name=None):
170 pl, const = op["operands"]
171 shape = testGen.makeShape(rng, rank)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100172
173 # Constrict the overall size of the shape when creating ERROR_IF tests
174 if error_name:
175 shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
176
177 shape_list = []
178 for i in range(pl + const):
179 shape_list.append(shape.copy())
180
Luke Huttona4e48ca2023-02-22 11:53:48 +0000181 # Generates an input rank mismatch for operators with more than one input
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100182 if error_name == ErrorIf.RankMismatch:
183 if rank == 1 and i != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100184 shape = testGen.makeShape(rng, rank + rng.choice([1, 2, 3]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100185 elif i != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100186 shape = testGen.makeShape(rng, rank + rng.choice([-1, 1]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100187
188 return shape_list
189
190 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100191 def tgNHWC(testGen, rng, op, rank, error_name=None):
192 pl, const = op["operands"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100193
194 if error_name != ErrorIf.WrongRank:
195 assert rank == 4
196
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100197 shape = testGen.makeShape(rng, rank)
James Ward30124a82023-02-02 14:56:33 +0000198 shape = testGen.constrictBatchSize(shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100199
200 # Constrict the overall size of the shape when creating ERROR_IF tests
201 if error_name and error_name != ErrorIf.MaxDimExceeded:
202 shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
203
204 shape_list = []
205 for i in range(pl + const):
206 shape_list.append(shape.copy())
207
208 return shape_list
209
210 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100211 def tgGather(testGen, rng, opName, rank, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100212 pl, const = opName["operands"]
213
214 assert pl == 2
215 assert const == 0
216 if error_name != ErrorIf.WrongRank:
217 assert rank == 3
218
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100219 values_shape = testGen.makeShape(rng, rank)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000220 values_shape = testGen.constrictBatchSize(values_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100221
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000222 N = values_shape[0]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100223 W = testGen.makeDimension(rng)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000224 indices_shape = [N, W]
225
226 shape_list = [values_shape, indices_shape]
227 return shape_list
228
229 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100230 def tgScatter(testGen, rng, opName, rank, error_name=None):
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000231 pl, const = opName["operands"]
232
233 assert pl == 3
234 assert const == 0
235 if error_name != ErrorIf.WrongRank:
236 assert rank == 3
237
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100238 values_in_shape = testGen.makeShape(rng, rank)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000239 values_in_shape = testGen.constrictBatchSize(values_in_shape)
240
241 N = values_in_shape[0]
242 K = values_in_shape[1]
243 C = values_in_shape[2]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100244
Jeremy Johnson194fe312023-12-07 14:17:57 +0000245 # Make sure W is not greater than K, as we can only write each output index
246 # once (having a W greater than K means that you have to repeat a K index)
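        # e.g. with K == 5, W is capped at 5 regardless of the configured shape range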
247 W_min = min(testGen.args.tensor_shape_range[0], K)
248 W_max = min(testGen.args.tensor_shape_range[1], K)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100249 W = rng.randInt(W_min, W_max) if W_min < W_max else W_min
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100250
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000251 input_shape = [N, W, C]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100252
253 shape_list = []
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000254 shape_list.append(values_in_shape)
255 shape_list.append([N, W]) # indices
256 shape_list.append(input_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100257
258 return shape_list
259
260 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100261 def _get_broadcast_shapes(testGen, rng, num_shapes, rank, error_name=None):
262 shape = testGen.makeShape(rng, rank)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100263 shape_list = []
264
265 # Choose one of the inputs to broadcast
266 # Note: Simplifies OutputShaper code if we don't change first shape for errors
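        # e.g. for shape [3, 4] and fuzz_idx 1, the broadcast input becomes [3, 1]
        # while the other inputs keep the full [3, 4] shape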
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100267 bcast_idx = rng.randInt(0 if error_name is None else 1, num_shapes)
268 fuzz_idx = rng.randInt(0, rank)
Jerry Ge135c9552023-05-23 20:59:32 +0000269
Jeremy Johnson0a042992024-02-28 13:20:05 +0000270 for i in range(num_shapes):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100271 shape_bcast = shape.copy()
272
Jerry Ge135c9552023-05-23 20:59:32 +0000273 # To test broadcasting, the chosen fuzz index dimension should not be 1
274 if shape_bcast[fuzz_idx] == 1:
275 shape_bcast[fuzz_idx] += 1
276
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100277 # If the chosen input, pick a random index to broadcast
278 if i == bcast_idx:
Jerry Ge135c9552023-05-23 20:59:32 +0000279 if error_name == ErrorIf.RankMismatch:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100280 # Add one rank to the shape (or more for rank of 1)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100281 extra_ranks = rng.choice([1, 2, 3]) if rank == 1 else 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100282 shape_bcast = np.concatenate(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100283 (shape_bcast, testGen.makeShape(rng, extra_ranks))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100284 )
285 if rank != 1:
286 # Either keep the extra rank, or remove it
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100287 new_len = rng.choice([-2, len(shape_bcast)])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100288 shape_bcast = shape_bcast[:new_len]
Jerry Ge135c9552023-05-23 20:59:32 +0000289 elif error_name == ErrorIf.BroadcastShapesMismatch:
290 shape_bcast[fuzz_idx] += 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100291 else:
292 shape_bcast[fuzz_idx] = 1
293
294 shape_list.append(shape_bcast)
295
296 return shape_list
297
298 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100299 def tgBroadcastFuzz(testGen, rng, op, rank, error_name=None):
Jeremy Johnson0a042992024-02-28 13:20:05 +0000300 pl, const = op["operands"]
301 num_shapes = pl + const
302 return TosaTensorGen._get_broadcast_shapes(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100303 testGen, rng, num_shapes, rank, error_name
Jeremy Johnson0a042992024-02-28 13:20:05 +0000304 )
305
306 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100307 def tgMul(testGen, rng, op, rank, error_name=None):
Jeremy Johnson0a042992024-02-28 13:20:05 +0000308 # Get broadcast shapes for the first 2 inputs as the 3rd is shift
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100309 shape_list = TosaTensorGen._get_broadcast_shapes(
310 testGen, rng, 2, rank, error_name
311 )
Jeremy Johnson0a042992024-02-28 13:20:05 +0000312 # Add a single dimension tensor for shift
313 shape_list.append([1])
314 return shape_list
315
316 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100317 def tgConv2D(testGen, rng, op, rank, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100318 pl, const = op["operands"]
319
320 if error_name != ErrorIf.WrongRank:
321 assert rank == 4
322
323 # IFM dimensions are NHWC
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100324 ifm_shape = testGen.makeShape(rng, rank)
James Ward30124a82023-02-02 14:56:33 +0000325 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100326
327 # Constrict the overall size of the shape when creating ERROR_IF tests
328 if error_name:
329 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
330 ifm_shape, max_dim=24, max_items=10000
331 )
332
333 # Get the filter height/width from the operator parameters
334 filter_hw = op["filter"]
335
336 # Generate a random OFM depth
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100337 ofm_depth = testGen.makeDimension(rng)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100338
339 # The filter dimensions are OHWI
340 filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
341
342 # The bias is OC
343 bias_shape = np.asarray([ofm_depth])
344
345 return [ifm_shape, filter_shape, bias_shape]
346
347 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100348 def tgConv3D(testGen, rng, op, rank, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100349 pl, const = op["operands"]
350
351 if error_name != ErrorIf.WrongRank:
352 assert rank == 5
353
354 # IFM dimensions are NDHWC
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100355 ifm_shape = testGen.makeShape(rng, rank)
James Ward30124a82023-02-02 14:56:33 +0000356 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100357
358 # Constrict the overall size of the shape when creating ERROR_IF tests
359 if error_name:
360 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
361 ifm_shape, max_dim=24, max_items=10000
362 )
363
364 # Get the filter depth/height/width from the operator parameters
365 filter_dhw = op["filter"]
366
367 # Generate a random OFM channel
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100368 ofm_channel = testGen.makeDimension(rng)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100369
370 # The filter dimensions are ODHWI
371 filter_shape = np.asarray(
372 [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
373 )
374
375 # The bias is OC
376 bias_shape = np.asarray([ofm_channel])
377
378 return [ifm_shape, filter_shape, bias_shape]
379
380 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100381 def tgTransposeConv2D(testGen, rng, op, rank, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100382 pl, const = op["operands"]
383
384 if error_name != ErrorIf.WrongRank:
385 assert rank == 4
386
387 # IFM dimensions are NHWC
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100388 ifm_shape = testGen.makeShape(rng, rank)
James Ward30124a82023-02-02 14:56:33 +0000389 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100390
391 # Constrict the overall size of the shape when creating ERROR_IF tests
392 if error_name:
393 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
394 ifm_shape, max_dim=24, max_items=10000
395 )
396
397 # Get the filter height/width from the operator parameters
398 filter_hw = op["filter"]
399
400 # Generate a random OFM depth
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100401 ofm_depth = testGen.makeDimension(rng)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100402
403 # The filter dimensions are OHWI
404 filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
405
406 # The bias is OC
407 bias_shape = np.asarray([ofm_depth])
408
409 return [ifm_shape, filter_shape, bias_shape]
410
411 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100412 def tgDepthwiseConv2D(testGen, rng, op, rank, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100413 pl, const = op["operands"]
414
415 if error_name != ErrorIf.WrongRank:
416 assert rank == 4
417 assert pl == 1 and const == 2
418
419 # IFM dimensions are NHWC
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100420 ifm_shape = testGen.makeShape(rng, rank)
James Ward30124a82023-02-02 14:56:33 +0000421 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100422
423 # Constrict the overall size of the shape when creating ERROR_IF tests
424 if error_name:
425 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
426 ifm_shape, max_dim=24, max_items=10000
427 )
428
429 # Get the filter height/width from the operator parameters
430 # Filter is KH, KW, C, M
431 filter_hw = op["filter"]
432
433 # Generate a random OFM depth, but don't let it get too big because
434 # the output depth is M * C
435 filter_m = (
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100436 testGen.makeDimension(rng) % (testGen.args.tensor_shape_range[1] // 4)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100437 ) + 1
438
439 # The filter dimensions are HWCM
440 filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])
441
442 # The bias is M * C
443 bias_shape = np.asarray([ifm_shape[3] * filter_m])
444
445 return [ifm_shape, filter_shape, bias_shape]
446
447 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100448 def tgFFT2d(testGen, rng, op, rank, error_name=None):
Luke Hutton57287132023-02-06 14:54:18 +0000449 pl, const = op["operands"]
450
451 if error_name != ErrorIf.WrongRank:
452 assert rank == 3
453 assert pl == 2 and const == 0
454
455 # IFM dimensions are NHW
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100456 ifm_shape = testGen.makeShape(rng, rank)
Luke Hutton57287132023-02-06 14:54:18 +0000457
458 # Select nearest lower power of two from input height and width
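        # e.g. a randomly chosen height of 13 is rounded down to 8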
459 ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
460 ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))
461
462 # Constrict the overall size of the shape when creating ERROR_IF tests
463 if error_name:
464 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)
465
466 # Generate an invalid kernel that is not a power of two
467 if error_name == ErrorIf.KernelNotPowerOfTwo:
468 inc_h = 2 if ifm_shape[1] == 1 else 1
469 inc_w = 2 if ifm_shape[2] == 1 else 1
470 inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100471 selected_inc = rng.choice(inc_choices)
Luke Hutton57287132023-02-06 14:54:18 +0000472 ifm_shape[1] += selected_inc[0]
473 ifm_shape[2] += selected_inc[1]
474
475 ifm_shape = testGen.constrictBatchSize(ifm_shape)
476
477 ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
478 if error_name == ErrorIf.FFTInputShapeMismatch:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100479 modify_shape = rng.choice([0, 1])
Luke Hutton57287132023-02-06 14:54:18 +0000480 # Only modify kernel (H, W)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100481 modify_dim = rng.choice([1, 2])
Luke Hutton57287132023-02-06 14:54:18 +0000482 ifm_shapes[modify_shape][modify_dim] *= 2
483
484 return [ifm_shapes[0], ifm_shapes[1]]
485
486 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100487 def tgRFFT2d(testGen, rng, op, rank, error_name=None):
Luke Hutton261b7b62023-01-10 14:50:31 +0000488 pl, const = op["operands"]
489
490 if error_name != ErrorIf.WrongRank:
491 assert rank == 3
492 assert pl == 1 and const == 0
493
494 # IFM dimensions are NHW
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100495 ifm_shape = testGen.makeShape(rng, rank)
Luke Hutton261b7b62023-01-10 14:50:31 +0000496
497 # Select nearest lower power of two from input height and width
498 ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
499 ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))
500
501 # Constrict the overall size of the shape when creating ERROR_IF tests
502 if error_name:
503 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)
504
505 # Generate an invalid kernel that is not a power of two
506 if error_name == ErrorIf.KernelNotPowerOfTwo:
507 # We must increment by 2 if current size is 1
508 inc_h = 2 if ifm_shape[1] == 1 else 1
509 inc_w = 2 if ifm_shape[2] == 1 else 1
510 inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100511 selected_inc = rng.choice(inc_choices)
Luke Hutton261b7b62023-01-10 14:50:31 +0000512 ifm_shape[1] += selected_inc[0]
513 ifm_shape[2] += selected_inc[1]
514
James Ward30124a82023-02-02 14:56:33 +0000515 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Luke Hutton261b7b62023-01-10 14:50:31 +0000516
517 return [ifm_shape]
518
519 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100520 def tgFullyConnected(testGen, rng, op, rank, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100521 pl, const = op["operands"]
522
523 if error_name != ErrorIf.WrongRank:
524 assert rank == 2
525
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100526 input_shape = testGen.makeShape(rng, rank)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100527
528 # Constrict the overall size of the shape when creating ERROR_IF tests
529 if error_name:
530 input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)
531
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100532 filter_oc = rng.integers(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100533 low=testGen.args.tensor_shape_range[0],
534 high=testGen.args.tensor_shape_range[1],
535 size=1,
536 )[0]
537 filter_shape = np.asarray([filter_oc, input_shape[1]])
538
539 bias_shape = np.asarray([filter_oc])
540
541 return [input_shape, filter_shape, bias_shape]
542
543 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100544 def tgMatmul(testGen, rng, op, rank, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100545 pl, const = op["operands"]
546
547 if error_name != ErrorIf.WrongRank:
548 assert rank == 3
549 assert pl == 2 and const == 0
550
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100551 a_shape = testGen.makeShape(rng, rank)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100552
553 # Constrict the overall size of the shape when creating ERROR_IF tests
554 if error_name:
555 a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)
556
557 # Get a random number for b_oc even if target shape is defined
558 b_oc = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100559 rng.integers(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100560 low=testGen.args.tensor_shape_range[0],
561 high=testGen.args.tensor_shape_range[1],
562 size=1,
563 )
564 )[0]
565 # If N or H is large let b_oc be 1 to reduce output tensor size
566 if max(a_shape) > 1000:
567 b_oc = 1
568
569 b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
570 return [a_shape, b_shape]
571
572 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100573 def tgConcat(testGen, rng, op, rank, error_name=None):
574 pl, const = op["operands"]
575 shape = testGen.makeShape(rng, rank)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100576
577 # Create extra tensors to concat.
578 # Take into account value of pl when getting maximum number of concats
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100579 num_tensors = rng.randInt(0, 4)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100580 shape_list = []
581 for i in range(pl + const + num_tensors):
582 if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100583 remove = rng.choice([True, False])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100584 wrongShape = shape.copy()
585
586 if remove and len(shape) > 1:
587 wrongShape = wrongShape[1:]
588 else:
589 wrongShape = list(wrongShape)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100590 wrongShape.append(rng.integers(1, 10))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100591
592 shape_list.append(wrongShape)
593 else:
594 shape_list.append(shape.copy())
595
596 return shape_list
597
598 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100599 def tgConcatConstInput(rng, shapeList, axis, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100600 if error_name in [
601 ErrorIf.AxisSmallerZero,
602 ErrorIf.AxisLargerRank,
603 ErrorIf.ConcatInputRankMismatch,
604 ]:
605 return shapeList
606
607 # Split concat shape along axis to allow for multiple const inputs
608 # without making too many large tensors
609 if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
610 # If axis can't be split we still need to invalidate other dimensions
611 if error_name == ErrorIf.ConcatInputDimMismatch:
612 for shape in shapeList[1:]:
613 # Negative test shapeLists are created individually for each test,
614 # so no need to copy the shape before altering it.
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100615 shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100616 return shapeList
617
618 # Create copy of shape we are going to split (so we don't alter shapeList)
619 shape = shapeList[0].copy()
620 # Add original shape as first input
621 new_shapeList = [shape.copy()]
622 length_on_axis = shape[axis]
623 remaining_length = length_on_axis
624 for i in range(len(shapeList) - 2):
625 # Calculate split on axis and remaining value
626 split_shape_val = int(shape[axis] / 2)
627 remaining_length = remaining_length - split_shape_val
628
629 # Append new shape, and set remaining shape
630 shape[axis] = split_shape_val
631 new_shapeList.append(shape.copy())
632
633 # invalidate dimensions
634 if error_name == ErrorIf.ConcatInputDimMismatch:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100635 shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100636 else:
637 shape[axis] = remaining_length
638
639 if i == len(shapeList) - 3:
640 new_shapeList.append(shape.copy())
641
642 return new_shapeList
643
644
645class TosaTensorValuesGen:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100646 """Tensor Value generators create the random data for each tensor in each test."""
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100647
648 def __init__(self):
649 pass
650
Jeremy Johnson1271c442023-09-05 11:39:26 +0100651 class TVGInfo:
652 """Enhanced tensor values information including data gen dict."""
653
654 def __init__(self, tensorList, dataGenDict):
655 self.tensorList = tensorList
656 self.dataGenDict = dataGenDict
657
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100658 # Default high value for random numbers
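    # (each entry is the largest finite value representable in that format,
    #  e.g. 65504 for FP16)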
659 TVG_FLOAT_HIGH_VALUE = {
660 DType.FP32: (1 << 128) - (1 << (127 - 23)),
661 DType.FP16: (1 << 16) - (1 << (15 - 10)),
662 DType.BF16: (1 << 128) - (1 << (127 - 7)),
Won Jeon2c34b462024-02-06 18:37:00 +0000663 DType.FP8E4M3: 448,
664 DType.FP8E5M2: 57344,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100665 }
666
Jeremy Johnson30476252023-11-20 16:15:30 +0000667 # Default lowest normal values for random numbers
668 TVG_FLOAT_LOW_VALUE = {
669 DType.FP32: np.exp2(-126),
670 DType.FP16: np.exp2(-14),
671 DType.BF16: np.exp2(-126),
Won Jeon2c34b462024-02-06 18:37:00 +0000672 DType.FP8E4M3: np.exp2(-9),
673 DType.FP8E5M2: np.exp2(-16),
Jeremy Johnson30476252023-11-20 16:15:30 +0000674 }
675
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100676 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100677 def _get_data_range(rng, dtype, highValueLookup, lowValueLookup=None):
Jeremy Johnson30476252023-11-20 16:15:30 +0000678 # Return a tuple of (low,high) data range values for the given data
679 # type using a combination of per operator table limits, data limits
680 # and user supplied ranges for FP numbers
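        # e.g. for FP16, TVG_FLOAT_HIGH_VALUE_ADDSUB caps the range at roughly
        # +/-32752, intersected with the range returned by rng.dTypeRange()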
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000681 if dtype in highValueLookup:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100682 type_range = rng.dTypeRange(dtype, high_inclusive=True)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000683 high_val = highValueLookup[dtype]
Jeremy Johnson30476252023-11-20 16:15:30 +0000684 if lowValueLookup is not None and dtype in lowValueLookup:
685 low_val = lowValueLookup[dtype]
686 else:
687 low_val = -high_val
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000688 # Clamp to the intersection of the type's default range and the
Jeremy Johnson30476252023-11-20 16:15:30 +0000689 # low/high limits so that the generated values cannot
690 # produce infinity
691 data_range = (
692 max(low_val, type_range[0]),
693 min(high_val, type_range[1]),
694 )
695 if data_range[0] > data_range[1]:
696 # The user constraints created an invalid data range (low > high),
697 # so revert to using the internal ranges as they are
698 # known to work
Jeremy Johnsonaf090182024-02-13 18:25:39 +0000699 logger.info(
700 f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
701 )
Jeremy Johnson30476252023-11-20 16:15:30 +0000702 data_range = (low_val, high_val)
703 return data_range
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000704 return None
705
706 @staticmethod
Jeremy Johnson1271c442023-09-05 11:39:26 +0100707 def tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100708 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson1271c442023-09-05 11:39:26 +0100709 ):
710 # Variable inputs versus constants
711 pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
Jeremy Johnson3eafe662024-01-10 13:13:35 +0000712 if "p_count" in argsDict:
713 # Override for operators like CONCAT
714 pCount = argsDict["p_count"]
715 cCount = argsDict["c_count"]
716 assert pCount + cCount == len(
717 shapeList
718 ), "Placeholders & Constant tensors must match shapes list"
719
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000720 tens_ser_list = []
Jeremy Johnson1271c442023-09-05 11:39:26 +0100721
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100722 if (
723 error_name is not None
724 or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100725 or "data_gen" not in testGen.TOSA_OP_LIST[opName]
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100726 ):
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000727 # Fall back to internal data gen when dealing with unsupported types or ops
728 data_range = argsDict["data_range"] if "data_range" in argsDict else None
729 for idx, info in enumerate(zip(shapeList, dtypeList)):
Jeremy Johnson30476252023-11-20 16:15:30 +0000730 roundMode = False
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000731 shape, dtype = info
Jeremy Johnson30476252023-11-20 16:15:30 +0000732 if "data_range_list" in argsDict:
733 data_range = argsDict["data_range_list"][idx]["range"]
734 roundMode = (
735 "round" in argsDict["data_range_list"][idx]
736 and argsDict["data_range_list"][idx]["round"] is True
737 )
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000738 if data_range is not None and dtype not in (
739 DType.FP16,
740 DType.FP32,
741 DType.BF16,
Won Jeon2c34b462024-02-06 18:37:00 +0000742 DType.FP8E4M3,
743 DType.FP8E5M2,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000744 ):
745 # Change from inclusive to exclusive range
746 data_range = (data_range[0], data_range[1] + 1)
Won Jeon64e4bfe2024-01-18 06:31:55 +0000747
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100748 # Ignore lazy data gen option and create data array using any range limits
Won Jeon64e4bfe2024-01-18 06:31:55 +0000749 if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
Jeremy Johnson0a042992024-02-28 13:20:05 +0000750 if dtype == DType.SHAPE:
751 arr = np.int64(argsDict["fixed_data"][idx])
752 elif dtype == DType.INT8:
753 arr = np.int8(argsDict["fixed_data"][idx])
Tai Ly6e1e2bc2024-03-01 20:59:32 +0000754 elif dtype == DType.INT16:
755 arr = np.int16(argsDict["fixed_data"][idx])
756 elif dtype == DType.INT32:
757 arr = np.int32(argsDict["fixed_data"][idx])
Jeremy Johnson0a042992024-02-28 13:20:05 +0000758 else:
759 assert False, "Unsupported fixed_data type"
Won Jeon64e4bfe2024-01-18 06:31:55 +0000760 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100761 arr = rng.randTensor(shape, dtype, data_range)
Jeremy Johnson30476252023-11-20 16:15:30 +0000762 if roundMode:
763 arr = np.round(arr)
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000764 if idx < pCount:
765 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
766 else:
767 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100768
Jeremy Johnson1271c442023-09-05 11:39:26 +0100769 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
770
771 # Create data generator meta-data
772 dg_type = argsDict["dg_type"]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100773 tens_data = {
774 "version": "0.1",
775 "tensors": {},
776 }
777 dg_tens_meta = tens_data["tensors"]
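        # Illustrative example of one generated tensor entry (names/values will vary):
        #   {"generator": "PSEUDO_RANDOM", "data_type": "FP32", "shape": [1, 8, 8],
        #    "input_pos": 0, "op": "ADD", "input_type": "VARIABLE",
        #    "pseudo_random_info": {"rng_seed": 42, "range": ["-100.0", "100.0"]}}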
Jeremy Johnson1271c442023-09-05 11:39:26 +0100778 for idx, shape in enumerate(shapeList):
779
780 tens_meta = {}
Won Jeon64e4bfe2024-01-18 06:31:55 +0000781 if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
782 tens_meta["generator"] = gtu.DataGenType(
783 gtu.DataGenType.FIXED_DATA
784 ).name
785 else:
786 tens_meta["generator"] = gtu.DataGenType(dg_type).name
787
Jeremy Johnson1271c442023-09-05 11:39:26 +0100788 tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
789 tens_meta["shape"] = [int(i) for i in shape]
790 tens_meta["input_pos"] = idx
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100791 tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()
Jeremy Johnson1271c442023-09-05 11:39:26 +0100792
793 if idx < pCount:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100794 tens_meta["input_type"] = "VARIABLE"
Jeremy Johnson1271c442023-09-05 11:39:26 +0100795 else:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100796 tens_meta["input_type"] = "CONSTANT"
Jeremy Johnson1271c442023-09-05 11:39:26 +0100797
798 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
799 info = {}
Won Jeon64e4bfe2024-01-18 06:31:55 +0000800 if (
801 tens_meta["generator"]
802 == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name
803 ):
804 info["data"] = [int(i) for i in argsDict["fixed_data"][idx]]
805 tens_meta["fixed_data_info"] = info
806 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100807 info["rng_seed"] = rng.seed
Jeremy Johnson30476252023-11-20 16:15:30 +0000808
Won Jeon64e4bfe2024-01-18 06:31:55 +0000809 data_range = None
810 if "data_range_list" in argsDict:
811 data_range = argsDict["data_range_list"][idx]["range"]
812 if "round" in argsDict["data_range_list"][idx]:
813 info["round"] = argsDict["data_range_list"][idx]["round"]
814 elif "data_range" in argsDict:
815 data_range = argsDict["data_range"]
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000816
Won Jeon64e4bfe2024-01-18 06:31:55 +0000817 if data_range is None:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100818 data_range = rng.dTypeRange(dtypeList[idx], high_inclusive=True)
Won Jeon64e4bfe2024-01-18 06:31:55 +0000819 info["range"] = [str(v) for v in data_range]
820 tens_meta["pseudo_random_info"] = info
Jeremy Johnson1271c442023-09-05 11:39:26 +0100821 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
822 info = {}
823 info["s"] = argsDict["s"]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100824 info["ks"] = int(argsDict["ks"])
825 if "acc_type" in argsDict:
826 # Convert type number into JSON name
827 info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
828 "json"
829 ]
830 if "kernel" in argsDict:
831 info["kernel"] = [int(k) for k in argsDict["kernel"]]
832 if "axis" in argsDict:
833 info["axis"] = int(argsDict["axis"])
Jeremy Johnson1271c442023-09-05 11:39:26 +0100834 tens_meta["dot_product_info"] = info
evacha019c96eef2024-02-07 11:21:55 +0000835 elif dg_type == gtu.DataGenType.FULL_RANGE:
836 info = {}
837 info["start_val"] = int(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100838 rng.randInt(0, gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["fullset"])
evacha019c96eef2024-02-07 11:21:55 +0000839 )
840 tens_meta["full_range_info"] = info
Jeremy Johnson1271c442023-09-05 11:39:26 +0100841 else:
842 # TODO - other data gen type
843 assert False, "TODO: support other data gen types"
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100844
845 # Using the finished generate config meta data - generate the data if
846 # needed and assign a tensor name from the serializer
847
848 # Need to generate data when not lazy or for the bias tensor as we need
849 # to work out if the bias data is non-zero for compliance
850 if not testGen.args.lazy_data_gen or (
851 idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
852 ):
853 # Give this tensor a temporary name until we get one from the serializer
854 temp_name = f"placeholder_{idx}"
855 dg_tens_meta[temp_name] = tens_meta
856 # Create data now using the temporary name to access meta details
857 data = testGen.dgl.get_tensor_data(temp_name, tens_data)
Won Jeon64e4bfe2024-01-18 06:31:55 +0000858 if tens_meta["data_type"] == "SHAPE":
859 # Tensor type SHAPE and Numpy file type must be the same
860 data = np.int64(data)
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100861 # Remove the item as we will give it the correct name later
862 del dg_tens_meta[temp_name]
863
864 if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
865 # The KS value used by compliance verification is altered when the
866 # bias data is non-zero
867 if max(abs(data)) > 0.0:
868 argsDict["ksb"] = argsDict["ks"] + 1
869
870 if testGen.args.lazy_data_gen:
871 data = None
872
873 if tens_meta["input_type"] == "VARIABLE":
874 tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
875 else:
876 tens = testGen.ser.addConst(shape, dtypeList[idx], data)
877
878 tens_ser_list.append(tens)
879 # Add the meta data to the list using the serializer tensor name
Jeremy Johnson1271c442023-09-05 11:39:26 +0100880 dg_tens_meta[tens.name] = tens_meta
881
Jeremy Johnson1271c442023-09-05 11:39:26 +0100882 return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
883
884 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100885 def tvgNegate(
886 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
887 ):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100888 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000889 # Integer test
890 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100891 pCount, cCount = op["operands"]
892 assert (
893 pCount == 1 and cCount == 0
894 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100895 # Must create tensors with values within accumulator (int32) negatable
896 # range
897 max_val = (1 << 31) - 1
898 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100899 arr = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100900 rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100901 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000902 tens_ser_list = []
903 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100904 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
905 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000906 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100907 else:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000908 # ERROR_IF or floating point test
909 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100910 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100911 )
912
Jeremy Johnson30476252023-11-20 16:15:30 +0000913 # Set the ADD/SUB data range to half the largest value to avoid infinities
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000914 TVG_FLOAT_HIGH_VALUE_ADDSUB = {
915 DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
916 DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
917 DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
918 }
919
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100920 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100921 def tvgAddSub(
922 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
923 ):
Won Jeon74342e52024-01-09 00:34:40 +0000924 if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000925 # Make sure the integer operation does not cause value saturation - where
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100926 # the number wraps due to limited number of bits to store the answer
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000927 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100928 pCount, cCount = op["operands"]
929 assert (
930 pCount == 2 and cCount == 0
931 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000932 tens_ser_list = []
Won Jeon74342e52024-01-09 00:34:40 +0000933 add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
934 data_range = testGen.args.tensor_shape_range
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100935 a_arr = rng.randTensor(shapeList[0], dtypeList[0], data_range)
936 b_arr = rng.randTensor(shapeList[1], dtypeList[1], data_range)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100937 if add:
938 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
939 else:
940 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
941
942 # Work out the saturation limits
943 max_i32 = (1 << 31) - 1
944 min_i32 = -(1 << 31)
945 max_arr = np.full(shapeList[1], max_i32)
946 min_arr = np.full(shapeList[1], min_i32)
947
948 # Find how much values exceed the maximum/minimums
949 sat_max_arr = np.maximum(res_arr - max_arr, 0)
950 sat_min_arr = np.minimum(res_arr - min_arr, 0)
951
952 if not add:
953 # Swap saturation values and negate values as we need to perform opposite operations
954 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
955
956 # Create new array of unsaturated values by clipping values as needed
957 b_unsat_arr = b_arr
958 if (sat_max_arr != 0).any():
959 # Clip values that cause saturation
960 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
961 # Reduce axes in unsaturated tensor to match original tensor
962 for axis, dim in enumerate(b_arr.shape):
963 if dim != b_unsat_arr.shape[axis]:
964 assert (
965 dim == 1
966 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
967 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
968
969 if (sat_min_arr != 0).any():
970 # Clip values that cause saturation
971 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
972 # Reduce axes in unsaturated tensor to match original tensor
973 for axis, dim in enumerate(b_arr.shape):
974 if dim != b_unsat_arr.shape[axis]:
975 assert (
976 dim == 1
977 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
978 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
979
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000980 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100981 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
982 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000983 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100984 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
985 )
986
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000987 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100988 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000989 # ERROR_IF or floating point test
990 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100991 rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000992 )
993 if data_range:
994 argsDict["data_range"] = data_range
995
996 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100997 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100998 )
999
1000 @staticmethod
1001 def tvgCondIfWhileLoop(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001002 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001003 ):
1004 if dtypeList[0] in (
1005 DType.INT32,
1006 DType.INT16,
1007 DType.INT8,
1008 ):
1009 # Limit input tensors with cond_if_binary or while_loop to stop
1010 # saturation of add/sub ops with int32 and keep all logical shift
1011 # values between 0 to 31 for int16 or int8
Jeremy Johnson587cc842024-02-08 11:45:44 +00001012 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001013 pCount, cCount = op["operands"]
1014 pRemain = pCount
Jeremy Johnson587cc842024-02-08 11:45:44 +00001015 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001016 for idx, shape in enumerate(shapeList[:]):
1017 if dtypeList[0] == DType.INT32:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001018 arr = rng.randTensor(shapeList[idx], DType.INT16)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001019 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001020 arr = np.int32(rng.integers(low=0, high=32, size=shapeList[idx]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001021 if pRemain > 0:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001022 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001023 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
1024 )
1025 pRemain -= 1
1026 else:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001027 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001028 testGen.ser.addConst(shape, dtypeList[idx], arr)
1029 )
1030
Jeremy Johnson587cc842024-02-08 11:45:44 +00001031 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001032 else:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001033 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001034 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001035 )
1036
1037 @staticmethod
1038 def tvgArithmeticRightShift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001039 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001040 ):
Jeremy Johnson587cc842024-02-08 11:45:44 +00001041 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001042 pCount, cCount = op["operands"]
1043 # Force value of operand[1] to be within [0, num_bits]
1044 assert (
1045 pCount == 2 and cCount == 0
1046 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
1047
Jeremy Johnson587cc842024-02-08 11:45:44 +00001048 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001049 for idx, shape in enumerate(shapeList[:]):
1050 if idx == 1:
1051 if dtypeList[idx] == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001052 arr = np.int32(rng.integers(low=0, high=8, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001053 elif dtypeList[idx] == DType.INT16:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001054 arr = np.int32(rng.integers(low=0, high=16, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001055 elif dtypeList[idx] == DType.INT32:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001056 arr = np.int32(rng.integers(low=0, high=32, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001057 elif error_name == ErrorIf.WrongInputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001058 arr = np.int32(rng.integers(low=0, high=8, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001059 else:
1060 raise Exception("OpArithmeticRightShift: invalid input dtype")
1061 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001062 arr = rng.randTensor(shape, dtypeList[idx])
Jeremy Johnson587cc842024-02-08 11:45:44 +00001063 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001064
Jeremy Johnson587cc842024-02-08 11:45:44 +00001065 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001066
1067 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001068 def tvgReshape(
1069 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1070 ):
Won Jeon64e4bfe2024-01-18 06:31:55 +00001071 dtypeList[1] = DType.SHAPE
1072 shapeList[1] = [len(argsDict["new_shape"])]
1073 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1074 argsDict["fixed_data"] = [None, argsDict["new_shape"]]
1075
1076 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001077 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Won Jeon64e4bfe2024-01-18 06:31:55 +00001078 )
1079
1080 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001081 def tvgRescale(
1082 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1083 ):
Tai Ly6e1e2bc2024-03-01 20:59:32 +00001084 scale32 = argsDict["scale"]
1085 multiplier_arr = argsDict["multiplier"]
1086 shift_arr = argsDict["shift"]
1087
1088 if scale32:
1089 dtypeList[1] = DType.INT32
1090 else:
1091 dtypeList[1] = DType.INT16
1092 shapeList[1] = [len(multiplier_arr)]
1093 dtypeList[2] = DType.INT8
1094 shapeList[2] = [len(shift_arr)]
1095 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1096 argsDict["fixed_data"] = [None, multiplier_arr, shift_arr]
1097
1098 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001099 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Tai Ly6e1e2bc2024-03-01 20:59:32 +00001100 )
1101
1102 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001103 def tvgPad(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Tai Lye095da72024-01-25 22:00:18 +00001104 # argsDict["pad"] is a 2D array; flatten it to get the list of values
1105 pad_values = argsDict["pad"].flatten()
1106 dtypeList[1] = DType.SHAPE
1107 shapeList[1] = [len(pad_values)]
1108 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1109 argsDict["fixed_data"] = [None, pad_values]
1110
1111 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001112 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Tai Lye095da72024-01-25 22:00:18 +00001113 )
1114
1115 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001116 def tvgSlice(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
TatWai Chongf15bad82024-01-31 21:33:27 -08001117 dtypeList[1] = DType.SHAPE
1118 shapeList[1] = [len(argsDict["start"])]
1119 dtypeList[2] = DType.SHAPE
1120 shapeList[2] = [len(argsDict["size"])]
1121 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1122 argsDict["fixed_data"] = [None, argsDict["start"], argsDict["size"]]
1123
1124 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001125 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
TatWai Chongf15bad82024-01-31 21:33:27 -08001126 )
1127
1128 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001129 def tvgTile(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Won Jeon64e4bfe2024-01-18 06:31:55 +00001130 dtypeList[1] = DType.SHAPE
1131 shapeList[1] = [len(argsDict["multiples"])]
1132 argsDict["fixed_data"] = [None, argsDict["multiples"]]
1133
1134 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001135 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Won Jeon64e4bfe2024-01-18 06:31:55 +00001136 )
1137
1138 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001139 def tvgSelect(
1140 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1141 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001142 # Set datatype of condition tensor to boolean
1143 dtypeList[0] = DType.BOOL
1144
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00001145 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001146 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001147 )
1148
1149 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001150 def tvgIntDiv(
1151 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1152 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001153 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001154 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001155 pCount, cCount = op["operands"]
1156 assert (
1157 pCount == 2 and cCount == 0
1158 ), "Op.INTDIV must have 2 placeholders, 0 consts"
1159
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001160 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001161
1162 # Two invalid cases for Op.INTDIV:
1163 # 1. divisor == 0
1164 # 2. dividend == -(1<<31) and divisor == -1
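            # (case 2 would overflow, since +2**31 is not representable in int32)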
1165 while True:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001166 dividend_arr = rng.randTensor(shapeList[0], dtypeList[0])
1167 divisor_arr = rng.randTensor(shapeList[1], dtypeList[1])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001168
1169 if (divisor_arr == 0).any():
1170 continue
1171
1172 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
1173 continue
1174
1175 break
1176
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001177 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001178 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1179 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001180 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001181 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1182 )
1183
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001184 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001185 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001186 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001187 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001188 )
1189
Jeremy Johnson30476252023-11-20 16:15:30 +00001190 # Set the MUL data range to the square root of the largest value
1191 # to avoid infinities
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001192 TVG_FLOAT_HIGH_VALUE_MUL = {
1193 DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1194 DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1195 DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1196 }
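# sqrt(max) * sqrt(max) == max, so the product of two in-range operands
# cannot overflow to infinity.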
1197
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001198 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001199 def tvgMul(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001200 if error_name is not None or dtypeList[0] in (
1201 DType.FP16,
1202 DType.BF16,
1203 DType.FP32,
1204 ):
1205 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001206 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001207 rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001208 )
1209 if data_range:
1210 argsDict["data_range"] = data_range
1211
Jeremy Johnson0a042992024-02-28 13:20:05 +00001212 if dtypeList[0] != DType.SHAPE:
1213 # Need to supply shift tensor for MUL (not needed for MUL_SHAPE)
1214 dtypeList[2] = DType.INT8
1215 shapeList[2] = [1]
1216 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1217 argsDict["fixed_data"] = [None, None, [argsDict["shift"]]]
1218
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001219 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001220 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001221 )
1222 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001223 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001224 pCount, cCount = op["operands"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001225
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001226 tens_ser_list = []
1227
1228 # Make sure the multiply result is within the int32 range
Won Jeon74342e52024-01-09 00:34:40 +00001229 if dtypeList[0] == DType.SHAPE:
1230 shift = 0
1231 else:
1232 shift = argsDict["shift"]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001233 if dtypeList[0] == DType.INT8:
1234 num_bits = 8
1235 elif dtypeList[0] == DType.INT16:
1236 num_bits = 16
Won Jeon74342e52024-01-09 00:34:40 +00001237 elif dtypeList[0] in (DType.INT32, DType.SHAPE):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001238 num_bits = 32
1239 elif error_name == ErrorIf.WrongInputType:
1240 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001241 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001242 raise Exception(
1243 f"OpMul: invalid input dtype {gtu.DTYPE_ATTRIBUTES[dtypeList[0]]['str']}"
1244 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001245
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001246 for idx, shape in enumerate(shapeList[:]):
Won Jeon74342e52024-01-09 00:34:40 +00001247 if dtypeList[idx] == DType.SHAPE:
1248 low = testGen.args.tensor_shape_range[0]
1249 high = testGen.args.tensor_shape_range[1]
1250 else:
1251 low = -(2 ** (num_bits - 1))
1252 high = (2 ** (num_bits - 1)) - 1
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001253
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001254 a_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[0]))
1255 b_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[1]))
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001256
1257 i = 0
1258 while True:
1259
1260 a_arr_64 = a_arr.astype(np.int64)
1261 b_arr_64 = b_arr.astype(np.int64)
1262
1263 if shift > 0:
1264 rounding = 1 << (shift - 1)
1265 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001266 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001267 result_arr = a_arr_64 * b_arr_64
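# When shift > 0 the added rounding term gives round-half-up behaviour,
# e.g. with shift=2 a raw product of 6 becomes (6 + 2) >> 2 = 2 rather
# than truncating to 1.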
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001268
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001269 if (result_arr > -(2**31)).all() and (
1270 result_arr <= ((2**31) - 1)
1271 ).all():
1272 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001273
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001274 i = i + 1
1275 a_arr = a_arr // 2
1276 b_arr = b_arr // 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001277
Won Jeon74342e52024-01-09 00:34:40 +00001278 if dtypeList[0] == DType.SHAPE:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001279 # MUL_SHAPE with 2 inputs
Won Jeon74342e52024-01-09 00:34:40 +00001280 tens_ser_list.append(
1281 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
1282 )
1283 tens_ser_list.append(
1284 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
1285 )
1286 else:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001287 # MUL with 3 inputs (3rd is shift)
Won Jeon74342e52024-01-09 00:34:40 +00001288 tens_ser_list.append(
1289 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1290 )
1291 tens_ser_list.append(
1292 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1293 )
Jeremy Johnson0a042992024-02-28 13:20:05 +00001294 tens_ser_list.append(
1295 testGen.ser.addPlaceholder([1], DType.INT8, np.int8([shift]))
1296 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001297
1298 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001299
1300 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001301 def tvgConcat(
1302 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1303 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001304 count = len(shapeList) - testGen.args.num_const_inputs_concat
1305 if count < 1:
1306 count = 1
1307 if testGen.args.num_const_inputs_concat == 0:
1308 count = len(shapeList)
1309
Won Jeon74342e52024-01-09 00:34:40 +00001310 op = testGen.TOSA_OP_LIST[opName]
1311 if op["op"] == Op.CONCAT_SHAPE:
1312 # Set the axis to 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001313 shapeList = TosaTensorGen.tgConcatConstInput(rng, shapeList, 0, error_name)
Won Jeon74342e52024-01-09 00:34:40 +00001314 else:
1315 shapeList = TosaTensorGen.tgConcatConstInput(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001316 rng, shapeList, argsDict["axis"], error_name
Won Jeon74342e52024-01-09 00:34:40 +00001317 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001318
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001319 # Override default pCount/cCount for operator
1320 argsDict["p_count"] = count
1321 argsDict["c_count"] = len(shapeList) - count
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001322
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001323 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001324 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001325 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001326
1327 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001328 def tvgLogicalShift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001329 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001330 ):
1331 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001332 pCount, cCount = op["operands"]
1333 assert (
1334 pCount == 2 and cCount == 0
1335 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001336 values_arr = rng.randTensor(shapeList[0], dtypeList[0])
1337 shift_arr = np.int32(rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001338 tens_ser_list = []
1339 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001340 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1341 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001342 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001343 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1344 )
1345
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001346 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001347
1348 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001349 def tvgEqual(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnsona0150012023-11-15 15:52:06 +00001350 if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
1351 # Integer
1352 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001353 pCount, cCount = op["operands"]
1354 assert (
1355 pCount == 2 and cCount == 0
1356 ), "Op.EQUAL must have 2 placeholders, 0 consts"
Jeremy Johnsona0150012023-11-15 15:52:06 +00001357
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001358 a_arr = rng.randTensor(shapeList[0], dtypeList[0])
1359 b_arr = rng.randTensor(shapeList[1], dtypeList[1])
Jeremy Johnsona0150012023-11-15 15:52:06 +00001360
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001361 # Random data makes it very unlikely that any values match (are
1362 # equal), so force a number of matching values equal to twice the
1363 # tensor rank
1364 for num in range(0, len(shapeList[0]) * 2):
1365 a_index = []
1366 b_index = []
1367 # Choose an index in each axis for the whole shape
1368 for axis in range(0, len(shapeList[0])):
1369 # Index can be up to the largest dimension in both shapes
1370 index = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001371 rng.integers(0, max(shapeList[0][axis], shapeList[1][axis]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001372 )
1373 # Reduce the index down to a shape's dim for broadcasting
1374 a_index.append(min(shapeList[0][axis] - 1, index))
1375 b_index.append(min(shapeList[1][axis] - 1, index))
1376
1377 a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
1378
Jeremy Johnsona0150012023-11-15 15:52:06 +00001379 tens_ser_list = []
1380 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001381 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1382 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001383 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001384 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1385 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001386 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001387 else:
Jeremy Johnsona0150012023-11-15 15:52:06 +00001388 # ERROR_IF or floating point test
1389 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001390 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001391 )
1392
1393 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001394 def tvgReduceSum(
1395 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1396 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001397 dtype = dtypeList[0]
1398 if dtype == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001399 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001400 pCount, cCount = op["operands"]
1401 assert (
1402 pCount == 1 and cCount == 0
1403 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
1404 # Limit values so that the sum cannot exceed the range of an int32 during
1405 # summation of any axis
1406 range_val = int((1 << 31) / max(shapeList[0]))
1407 values_arr = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001408 rng.integers(low=-range_val, high=range_val, size=shapeList[0])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001409 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001410 tens_ser_list = []
1411 tens_ser_list.append(
Jeremy Johnson30476252023-11-20 16:15:30 +00001412 testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001413 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001414 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001415 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001416 # ERROR_IF or dot product floating point test
Jeremy Johnson30476252023-11-20 16:15:30 +00001417 if (
1418 error_name is None
1419 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
1420 ):
1421 # Limit ranges for (non error & non compliance) tests by using
1422 # values that can be summed on any axis to not hit infinity
1423 highval_lookup = {
1424 dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
1425 / max(shapeList[0])
1426 }
1427 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001428 rng, dtype, highval_lookup
Jeremy Johnson30476252023-11-20 16:15:30 +00001429 )
1430 assert data_range is not None
1431 argsDict["data_range"] = data_range
1432
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001433 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001434 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001435 )
1436
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001437 @staticmethod
1438 def tvgReduceProduct(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001439 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001440 ):
1441 dtype = dtypeList[0]
1442 if error_name is None:
1443 # Limit ranges for (non error) tests by using
1444 # values that can be multiplied on any axis to not hit infinity
1445 highval_lookup = {
1446 dtype: math.pow(
1447 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
1448 1 / max(shapeList[0]),
1449 )
1450 }
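# highval is the max(shape)-th root of the largest value, so multiplying
# up to max(shape) in-range values along any axis cannot reach infinity.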
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001451 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001452 assert data_range is not None
1453 argsDict["data_range"] = data_range
1454
1455 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001456 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001457 )
1458
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001459 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001460 def tvgResize(
1461 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1462 ):
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001463 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001464 rng,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001465 dtypeList[0],
1466 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1467 )
1468 if data_range:
1469 argsDict["data_range"] = data_range
1470 # Needed for compliance
1471 argsDict["max_abs_value"] = data_range[1]
1472
Tai Lyc5c2a7e2024-02-22 23:26:28 +00001473 scale_values = argsDict["scale"]
1474 offset_values = argsDict["offset"]
1475 border_values = argsDict["border"]
1476 dtypeList[1] = DType.SHAPE
1477 dtypeList[2] = DType.SHAPE
1478 dtypeList[3] = DType.SHAPE
1479 shapeList[1] = [len(scale_values)]
1480 shapeList[2] = [len(offset_values)]
1481 shapeList[3] = [len(border_values)]
1482 argsDict["fixed_data"] = [None, scale_values, offset_values, border_values]
1483
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001484 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001485 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001486 )
1487
Jeremy Johnson30476252023-11-20 16:15:30 +00001488 # Set the POW exponent high data range
1489 TVG_FLOAT_HIGH_VALUE_POW_EXP = {
1490 DType.FP32: 10.0,
1491 DType.FP16: 10.0,
1492 DType.BF16: 10.0,
1493 }
1494 # POW highest base value (within a safe margin of error) that can be raised
1495 # to a +ve exponent without the result becoming Infinity
1496 TVG_FLOAT_HIGH_VALUE_POW_BASE = {
1497 DType.FP32: math.floor(
1498 math.pow(
1499 TVG_FLOAT_HIGH_VALUE[DType.FP32],
1500 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1501 )
1502 ),
1503 DType.FP16: math.floor(
1504 math.pow(
1505 TVG_FLOAT_HIGH_VALUE[DType.FP16],
1506 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1507 )
1508 ),
1509 DType.BF16: math.floor(
1510 math.pow(
1511 TVG_FLOAT_HIGH_VALUE[DType.BF16],
1512 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1513 )
1514 ),
1515 }
1516 # POW lowest base value (within a safe margin of error) that can be raised
1517 # to a -ve exponent without the result becoming Infinity
1518 TVG_FLOAT_LOW_VALUE_POW_BASE = {
1519 DType.FP32: math.ceil(
1520 math.pow(
1521 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
1522 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1523 )
1524 * 1000
1525 )
1526 / 1000,
1527 DType.FP16: math.ceil(
1528 math.pow(
1529 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
1530 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1531 )
1532 * 1000
1533 )
1534 / 1000,
1535 DType.BF16: math.ceil(
1536 math.pow(
1537 1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
1538 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1539 )
1540 * 1000
1541 )
1542 / 1000,
1543 }
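# Taken together: HIGH_VALUE_POW_BASE[t] ** EXP and
# LOW_VALUE_POW_BASE[t] ** -EXP both stay within TVG_FLOAT_HIGH_VALUE[t],
# keeping POW results finite across the generated exponent range.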
1544
1545 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001546 def tvgPow(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001547 if error_name is not None:
1548 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001549 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001550 )
1551 dtype = dtypeList[0]
1552 # Different ranges for POW
1553 test_set = argsDict["s"]
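# test_set 0: positive base with fractional exponent
# test_set 1: positive base with integer exponent
# test_set 2: negative base with integer exponent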
1554 if test_set == 0:
1555 # Positive base with fractional exponent
1556 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001557 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001558 dtype,
1559 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1560 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1561 )
1562 exp_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001563 rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
Jeremy Johnson30476252023-11-20 16:15:30 +00001564 )
1565 exp_round = False
1566 else:
1567 # Integer exponent
1568 exp_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001569 rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
Jeremy Johnson30476252023-11-20 16:15:30 +00001570 )
1571 exp_round = True
1572 if test_set == 1:
1573 # Positive base
1574 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001575 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001576 dtype,
1577 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1578 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1579 )
1580 else:
1581 assert test_set == 2
1582 # Negative base
1583 # Supply new look up tables with negative values
1584 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001585 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001586 dtype,
1587 {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
1588 {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
1589 )
1590
1591 data_range_list = (
1592 {
1593 "range": base_range,
1594 },
1595 {
1596 "range": exp_range,
1597 "round": exp_round,
1598 },
1599 )
1600 argsDict["data_range_list"] = data_range_list
1601 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001602 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001603 )
1604
1605 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001606 def tvgLogRsqrt(
1607 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1608 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001609 # LOG & RSQRT data range from lowest expressible positive number to
1610 # largest to avoid NaNs
1611 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001612 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001613 dtypeList[0],
1614 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1615 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1616 )
1617 if data_range:
1618 argsDict["data_range"] = data_range
1619
1620 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001621 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001622 )
1623
1624 # Set the EXP data range to span the log of the smallest up to the log of
1625 # the largest representable values, to avoid infinities or a zero result
1626 TVG_FLOAT_HIGH_VALUE_EXP = {
1627 DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1628 DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1629 DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1630 }
1631 TVG_FLOAT_LOW_VALUE_EXP = {
1632 DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
1633 DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
1634 DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
1635 }
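# exp(log(HIGH)) == HIGH and exp(log(LOW)) == LOW, so inputs drawn from
# this range keep the EXP result between the smallest and largest
# representable values.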
1636
1637 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001638 def tvgExp(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001639 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001640 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001641 dtypeList[0],
1642 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1643 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1644 )
1645 if data_range:
1646 argsDict["data_range"] = data_range
1647
1648 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001649 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001650 )
1651
1652 @staticmethod
1653 def tvgFullyConnected(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001654 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson30476252023-11-20 16:15:30 +00001655 ):
1656 dtype = dtypeList[0]
1657 if (
1658 error_name is None
1659 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
Jeremy Johnson718f3472023-11-30 14:18:19 +00001660 and dtype in (DType.BF16,)
Jeremy Johnson30476252023-11-20 16:15:30 +00001661 ):
Jeremy Johnson718f3472023-11-30 14:18:19 +00001662 # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
Jeremy Johnson30476252023-11-20 16:15:30 +00001663 # Limit ranges for (non error & non compliance) FP tests by using
1664 # values that can be multiplied on any axis to not hit infinity/NaN
1665 IC = shapeList[0][1]
1666 highval_lookup = {
1667 dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
1668 }
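# The dot product accumulates IC terms, so keeping each operand below
# the IC-th root of the largest value avoids overflow to infinity.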
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001669 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnson30476252023-11-20 16:15:30 +00001670 assert data_range is not None
1671 argsDict["data_range"] = data_range
1672
1673 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001674 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001675 )
1676
Jeremy Johnson708da822023-11-15 16:25:45 +00001677 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001678 def tvgCast(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson708da822023-11-15 16:25:45 +00001679 in_dtype = dtypeList[0]
1680 out_dtype = argsDict["out_type"]
1681 # Create a lookup to limit the input tensor to the output type's maximum,
1682 # avoiding FP infinities and integer saturation
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001683 out_range = rng.dTypeRange(out_dtype, high_inclusive=True)
Jeremy Johnson708da822023-11-15 16:25:45 +00001684 highval_lookup = {in_dtype: out_range[1]}
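# e.g. a cast to INT16 limits input values to at most 32767 and a cast
# to FP16 to at most ~65504, so results cannot saturate or become
# infinite.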
1685 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001686 rng,
Jeremy Johnson708da822023-11-15 16:25:45 +00001687 in_dtype,
1688 highval_lookup,
1689 )
1690
1691 assert data_range is not None
1692 argsDict["data_range"] = data_range
1693
1694 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001695 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson708da822023-11-15 16:25:45 +00001696 )
1697
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001698 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001699 def tvgGather(
1700 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1701 ):
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001702 K = shapeList[0][1]
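# The values tensor is (N, K, C), so valid gather indices are 0 to K - 1.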
1703
1704 # Fix the type of the indices tensor
1705 dtypeList[1] = DType.INT32
1706
1707 dtype = dtypeList[0]
1708 if not gtu.dtypeIsSupportedByCompliance(dtype):
1709 # Test unsupported by data generator
1710 op = testGen.TOSA_OP_LIST[opName]
1711 pCount, cCount = op["operands"]
1712 assert (
1713 pCount == 2 and cCount == 0
1714 ), "Op.GATHER must have 2 placeholders, 0 consts"
1715
1716 tens_ser_list = []
1717 for idx, shape in enumerate(shapeList):
1718 dtype = dtypeList[idx]
1719 if idx != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001720 arr = rng.randTensor(shape, dtype)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001721 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1722 else:
1723 # Limit data range of indices tensor up to K (exclusive)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001724 arr = rng.randTensor(shape, dtype, (0, K))
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001725 # To match old functionality - create indices as CONST
1726 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1727
1728 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1729
1730 else:
1731 # ERROR_IF or floating point test
1732 # Use inclusive values up to index K - 1 for the indices tensor
1733 data_range_list = (
1734 {"range": None},
1735 {"range": (0, K - 1)},
1736 )
1737 argsDict["data_range_list"] = data_range_list
1738
1739 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001740 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001741 )
1742
1743 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001744 def tvgScatter(
1745 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1746 ):
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001747 K = shapeList[0][1]
1748 W = shapeList[2][1]
1749
1750 # Work out an indices tensor here with data that doesn't exceed the
1751 # dimension K of the values_in tensor and does NOT repeat the same K
1752 # location as needed by the spec:
1753 # "It is not permitted to repeat the same output index within a single
1754 # SCATTER operation and so each output index occurs at most once."
1755 assert K >= W, "Op.SCATTER W must be smaller than or equal to K"
1756
1757 # Fix the type of the indices tensor
1758 dtypeList[1] = DType.INT32
1759
1760 dtype = dtypeList[0]
1761 if not gtu.dtypeIsSupportedByCompliance(dtype):
1762 # Test unsupported by data generator
1763 op = testGen.TOSA_OP_LIST[opName]
1764 pCount, cCount = op["operands"]
1765 assert (
1766 pCount == 3 and cCount == 0
1767 ), "Op.SCATTER must have 3 placeholders, 0 consts"
1768
1769 tens_ser_list = []
1770 for idx, shape in enumerate(shapeList):
1771 dtype = dtypeList[idx]
1772 if idx != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001773 arr = rng.randTensor(shape, dtype)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001774 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1775 else:
1776 # Create the indices array
1777 assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
1778 arr = []
1779 for n in range(shape[0]):
1780 # Get a shuffled list of output indices (0 to K-1) and
1781 # limit length to W
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001782 arr.append(rng.permutation(K)[:W])
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001783 indices_arr = np.array(arr, dtype=np.int32) # (N, W)
1784 # To match old functionality - create indices as CONST
1785 tens_ser_list.append(
1786 testGen.ser.addConst(shape, dtype, indices_arr)
1787 )
1788
1789 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1790
1791 else:
1792 # ERROR_IF or floating point test
1793 # Use inclusive values up to index K - 1 for the indices tensor
1794 data_range_list = (
1795 {"range": None},
1796 {"range": (0, K - 1)},
1797 {"range": None},
1798 )
1799 argsDict["data_range_list"] = data_range_list
1800
1801 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001802 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001803 )
1804
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001805
1806class TosaArgGen:
1807 """Argument generators create exhaustive or random lists of attributes for
1808 operators that take attributes or other parameters.
1809
1810 The return value is a list of (descriptive_name, [arglist]) tuples where
1811 the descriptive_name is appended to the test name and the arglist is expanded
1812 as arguments to the operator build function.
1813 """
1814
1815 def __init__(self):
1816 pass
1817
1818 @staticmethod
evacha019c96eef2024-02-07 11:21:55 +00001819 def _add_data_generators(testGen, opName, shapeList, dtype, arg_list, error_name):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001820 """Add extra tests for each type of data generator for this op."""
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001821 if (
1822 error_name is None
1823 and "data_gen" in testGen.TOSA_OP_LIST[opName]
1824 and gtu.dtypeIsSupportedByCompliance(dtype)
1825 ):
Tai Ly60dc48c2024-03-08 22:19:41 +00001826 if gtu.dtypeIsFloat(dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001827 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
1828 else:
1829 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
1830 else:
1831 # Error test or no data generator types listed - assume random
1832 dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
1833
1834 # Expand arg list with other data generator types
1835 new_arg_list = []
1836 for dg_type in dataGenTypesList:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001837 for arg_str, args_dict in arg_list:
evacha019c96eef2024-02-07 11:21:55 +00001838
1839 if dg_type == gtu.DataGenType.FULL_RANGE:
1840 tensor_size = gtu.product(shapeList[0])
1841 if tensor_size >= gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]:
1842 # Large enough tensor data size for full range, add a single test
1843 num_test_sets = 0
1844 else:
1845 # Not enough data size for full range of values, revert to random numbers
1846 dg_type = gtu.DataGenType.PSEUDO_RANDOM
1847
Jeremy Johnson1271c442023-09-05 11:39:26 +01001848 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
Jeremy Johnson30476252023-11-20 16:15:30 +00001849 if error_name is None:
1850 num_test_sets = (
1851 args_dict["num_test_sets"]
1852 if "num_test_sets" in args_dict
1853 else 0
1854 )
1855 else:
evacha019c96eef2024-02-07 11:21:55 +00001856 # Add single test for pseudo random
Jeremy Johnson30476252023-11-20 16:15:30 +00001857 num_test_sets = 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01001858
1859 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
1860 # Extra tests for each dot product test set
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001861 dot_products = args_dict["dot_products"]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001862 if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001863 shape_info = (
1864 " ({})".format(testGen.shapeStr(args_dict["shape"]))
1865 if "shape" in args_dict
1866 else ""
1867 )
Jeremy Johnsonaf090182024-02-13 18:25:39 +00001868 logger.info(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001869 f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
Jeremy Johnson1271c442023-09-05 11:39:26 +01001870 )
1871 continue
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001872 # KS and acc_type is required by all dot product generators
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001873 assert "ks" in args_dict
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001874 assert "acc_type" in args_dict
Jeremy Johnson1271c442023-09-05 11:39:26 +01001875
Jeremy Johnson30476252023-11-20 16:15:30 +00001876 num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS
1877
1878 if num_test_sets > 0:
1879 for s in range(0, num_test_sets):
evacha019c96eef2024-02-07 11:21:55 +00001880 set_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
1881 set_args_dict = args_dict.copy()
1882 set_args_dict["s"] = s
1883 set_args_dict["dg_type"] = dg_type
1884 new_arg_list.append((set_arg_str, set_args_dict))
Jeremy Johnson30476252023-11-20 16:15:30 +00001885 else:
1886 # Default is a single test
evacha019c96eef2024-02-07 11:21:55 +00001887 new_args_dict = args_dict.copy()
1888 new_args_dict["dg_type"] = dg_type
1889 new_arg_list.append((arg_str, new_args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001890
1891 return new_arg_list
1892
1893 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001894 def agNone(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001895 """A trivial argument generator for operators that don't take any
1896 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001897 arg_list = TosaArgGen._add_data_generators(
1898 testGen,
1899 opName,
evacha019c96eef2024-02-07 11:21:55 +00001900 shapeList,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001901 dtype,
1902 [("", {})],
1903 error_name,
1904 )
1905 # Return list of tuples: (arg_str, args_dict)
1906 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001907
1908 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001909 def agPow(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001910 """Pow operator needs different test sets to cover random numbers
1911 without creating NaNs or Infs"""
1912 arg_list = TosaArgGen._add_data_generators(
1913 testGen,
1914 opName,
evacha019c96eef2024-02-07 11:21:55 +00001915 shapeList,
Jeremy Johnson30476252023-11-20 16:15:30 +00001916 dtype,
1917 [("", {"num_test_sets": 3})],
1918 error_name,
1919 )
1920 # Return list of tuples: (arg_str, args_dict)
1921 return arg_list
1922
1923 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001924 def agAxis(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001925 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001926 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001927 shape = shapeList[0]
1928
1929 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001930 # Set too small axis
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001931 axes = [rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001932 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001933 # Set too large axis
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001934 axes = [rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001935 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001936 # Create tests for each dimension
1937 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001938
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001939 opid = testGen.TOSA_OP_LIST[opName]["op"]
1940
1941 for a in axes:
1942 args_dict = {"axis": int(a)}
1943 if opid == Op.REDUCE_SUM:
Jeremy Johnsone52c0a32024-03-11 09:58:24 +00001944 output_shape = shape.copy()
1945 if error_name is None:
1946 # The dot_products only need to be calculated correctly for non
1947 # error_if tests, as error_if tests should never actually be run
1948 output_shape[a] = 1
1949 args_dict["dot_products"] = gtu.product(output_shape)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001950 args_dict["shape"] = shape
1951 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1952 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1953
1954 arg_list.append(("axis{}".format(a), args_dict))
1955
1956 arg_list = TosaArgGen._add_data_generators(
1957 testGen,
1958 opName,
evacha019c96eef2024-02-07 11:21:55 +00001959 shapeList,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001960 dtype,
1961 arg_list,
1962 error_name,
1963 )
1964 # Return list of tuples: (arg_str, args_dict)
1965 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001966
1967 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001968 def _calculate_sparsity(num_tests, sparsity_factor):
1969 sparsity = num_tests // sparsity_factor + 1
1970 # If there are only a small number of tests, just select them all
1971 if sparsity < 13:
1972 sparsity = 1
1973 # To get a variety of parameter combinations sparsity should not be a
1974 # multiple of 2, 3 or 5
1975 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1976 sparsity += 1
1977 return sparsity
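# Illustrative example: num_tests=1000 and sparsity_factor=120 gives
# 1000 // 120 + 1 = 9, which is then bumped past the multiples of 2, 3
# and 5 to 11, so roughly one in every eleven combinations is kept.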
1978
Jeremy Johnsondd975b82024-02-28 17:29:13 +00001979 # Maximum number of error_if variants to produce
1980 MAX_CONV_ERROR_IFS = 3
1981
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001982 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001983 def agConv(testGen, rng, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson0c716862023-04-13 17:18:19 +01001984 # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001985 arg_list = []
1986
Jeremy Johnson0c716862023-04-13 17:18:19 +01001987 if testGen.args.level8k and error_name is not None:
1988 # Don't produce negative large tests
1989 return arg_list
1990
1991 # Shape: Batches, (Depth), Height, Width, Channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001992 ifm_shape = shapeList[0]
Jeremy Johnson0c716862023-04-13 17:18:19 +01001993 # Shape: (OFM channels), (KD), KH, KW, IFM channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001994 filter_shape = shapeList[1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001995
Tai Lyf36f2562024-03-14 16:21:29 +00001996 accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
1997
1998 if error_name == ErrorIf.WrongAccumulatorType:
1999 accum_dtypes = (
2000 [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
2001 )
James Ward8b390432022-08-12 20:48:56 +01002002
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002003 # Op type checks
Jeremy Johnson0c716862023-04-13 17:18:19 +01002004 conv3d = opName.startswith("conv3d")
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002005 depthwise = opName.startswith("depthwise")
2006
2007 # Check the rank
Jeremy Johnson0c716862023-04-13 17:18:19 +01002008 rank = 5 if conv3d else 4
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002009 if error_name != ErrorIf.WrongRank:
2010 assert len(ifm_shape) == rank
2011 assert len(filter_shape) == rank
2012
Jeremy Johnson0c716862023-04-13 17:18:19 +01002013 # kernel rank omits channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002014 k_rank = rank - 2
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002015 k_pos = 0 if depthwise else 1
Jeremy Johnson0c716862023-04-13 17:18:19 +01002016 k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002017 # compliance size - KS
2018 k_size = gtu.product(k_shape)
2019 if not depthwise:
2020 k_size *= ifm_shape[-1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002021
Jeremy Johnson0c716862023-04-13 17:18:19 +01002022 if not testGen.args.level8k:
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002023 if error_name in (
2024 ErrorIf.PadSmallerZero,
2025 ErrorIf.StrideSmallerOne,
2026 ErrorIf.DilationSmallerOne,
2027 ):
2028 # Use specific invalid value(s)
2029 if error_name == ErrorIf.PadSmallerZero:
2030 # Create negative paddings but with positive opposite paddings
2031 neg_pad = rng.choice(range(-5, 0))
2032 p_vals = [neg_pad, abs(neg_pad)]
2033 else:
2034 p_vals = [0, 0]
2035 if error_name == ErrorIf.StrideSmallerOne:
2036 # Can't use stride=0, as it is used as a divisor when deriving the output shape
2037 s_vals = [rng.choice(range(-5, 0))]
2038 else:
2039 s_vals = [1]
2040 if error_name == ErrorIf.DilationSmallerOne:
2041 d_vals = [rng.choice(range(-5, 1))]
2042 else:
2043 d_vals = [1]
2044 p = p_vals * k_rank
2045 s = s_vals * k_rank
2046 d = d_vals * k_rank
2047
2048 # Fix values to produce valid error_if
2049 for index in range(k_rank):
2050 pad_offset = index * 2
2051 fixed = False
2052 while not fixed:
2053 partial = (
2054 ifm_shape[index + 1]
2055 - 1
2056 + p[pad_offset]
2057 + p[pad_offset + 1]
2058 - (k_shape[index] - 1) * d[index]
2059 )
2060 remainder = partial % s[index]
2061 if partial <= 0:
2062 p[pad_offset + 1] += abs(partial) + 1
2063 elif remainder:
2064 # Stride will be negative for StrideSmallerOne
2065 assert remainder < 0
2066 p[pad_offset + 1] += abs(remainder)
2067 else:
2068 fixed = True
2069 paddings = {tuple(p)}
2070 strides = {tuple(s)}
2071 dilations = {tuple(d)}
2072 logger.debug(f"agConv: error pad={p} stride={s} dilation={d}")
Jeremy Johnson0c716862023-04-13 17:18:19 +01002073 else:
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002074 # Generate comprehensive argument lists
Jeremy Johnson0c716862023-04-13 17:18:19 +01002075 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002076 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
Jeremy Johnson0c716862023-04-13 17:18:19 +01002077 # Stride must be greater than 1 to force non-integer error
2078 startStride = (
2079 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002080 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002081 s_vals = [
2082 x for x in range(startStride, testGen.args.max_conv_stride + 1)
2083 ]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002084 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002085
2086 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
2087 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002088
Jeremy Johnson0c716862023-04-13 17:18:19 +01002089 if not error_name and testGen.args.oversize:
2090 # add some oversize argument values
2091 if max(ifm_shape) < 64:
2092 bigPadding = 9
2093 paddings.update(
2094 {
2095 x
2096 for x in itertools.product(
2097 *([[0, bigPadding]] * (k_rank * 2))
2098 )
2099 }
2100 )
2101 bigStride = 8
2102 strides.update(
2103 {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
2104 )
2105 bigDilation = 7
2106 dilations.update(
2107 {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
2108 )
2109 max_dim_size = None
2110
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002111 if error_name:
2112 # Cycle through all error_if tests but we only keep the first few
2113 sparsity = 1
2114 else:
2115 # There are too many parameter combinations, so generate them sparsely
2116 sparsity_factor = 120
2117 sparsity = TosaArgGen._calculate_sparsity(
2118 len(paddings) * len(strides) * len(dilations), sparsity_factor
2119 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002120 else:
2121 # Only test 8k levels boundaries
2122 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2123 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2124 bigPadding = bigKernel
2125
2126 dilation_shape = [1] * k_rank
2127 pad_shape = [0] * k_rank * 2
2128 if conv3d:
2129 # Small stride, except for the big kernel case (see below), to keep
2130 # the tensor size/calculation small
2131 stride_shape = [1] * k_rank
2132 for idx in range(k_rank):
2133 pad_offset = idx * 2
2134 if k_shape[idx] == bigKernel:
2135 # Padding shape needs to account for tensor shape
2136 pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
2137 pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
2138 # Big stride to reduce output size
2139 stride_shape[idx] = bigKernel
2140 else:
2141 # Account for kernel size
2142 pad_shape[pad_offset] = k_shape[idx] - 1
2143 else:
2144 # Always have a large stride with extra padding and dilation to keep
2145 # tensor calculation reasonable
2146 stride_shape = [bigKernel] * k_rank
2147 for idx in range(k_rank):
2148 # Dilation shape must account for kernel size
2149 dilation_shape[idx] = bigKernel // k_shape[idx]
2150 # Padding shape needs to accommodate tensor/kernel & dilation
2151 pad_offset = idx * 2
2152 pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
2153 pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
2154
2155 strides = {tuple(stride_shape)}
2156 dilations = {tuple(dilation_shape)}
2157 paddings = {tuple(pad_shape)}
2158 # Create a limit for the output dimensions size
2159 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2160
2161 # Currently allow all combinations that are reasonable size
2162 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002163
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002164 more_tests = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002165 n = 0
Tai Lyf36f2562024-03-14 16:21:29 +00002166 for a in accum_dtypes:
2167 for s in sorted(list(strides)):
2168 for p in sorted(list(paddings)):
2169 for d in sorted(list(dilations)):
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002170 if (
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002171 more_tests
2172 and n % sparsity == 0
Tai Lyf36f2562024-03-14 16:21:29 +00002173 # the padded shape must exceed the dilation * kernel to get a positive
2174 # sized output shape
2175 and (ifm_shape[1] - 1 + p[0] + p[1])
2176 > d[0] * (k_shape[0] - 1)
2177 and (ifm_shape[2] - 1 + p[2] + p[3])
2178 > d[1] * (k_shape[1] - 1)
2179 and (
2180 k_rank < 3
2181 or (
2182 (ifm_shape[3] - 1 + p[4] + p[5])
2183 > d[2] * (k_shape[2] - 1)
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002184 )
2185 )
Tai Lyf36f2562024-03-14 16:21:29 +00002186 ):
2187 remainders = []
2188 outputs = []
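# Per dimension: output = ((input - 1 + pad_before + pad_after
# - (kernel - 1) * dilation) // stride) + 1; a non-zero remainder
# means the output size would not be an exact integer.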
2189 for index in range(k_rank):
2190 pad_offset = index * 2
2191 partial = (
2192 ifm_shape[index + 1]
2193 - 1
2194 + p[pad_offset]
2195 + p[pad_offset + 1]
2196 - (k_shape[index] - 1) * d[index]
2197 )
2198 remainders.append(partial % s[index])
2199 outputs.append((partial // s[index]) + 1)
2200
2201 if (
2202 # the parameters must produce integer exact output
2203 error_name != ErrorIf.ConvOutputShapeNonInteger
2204 and max(remainders) == 0
2205 ) or (
2206 error_name == ErrorIf.ConvOutputShapeNonInteger
2207 and max(remainders) > 0
2208 ):
2209 if (
2210 max_dim_size is not None
2211 and max(outputs) >= max_dim_size
2212 ):
2213 # Test will consume too much memory - skip it
2214 continue
2215
2216 # Compliance - number of dot product calculations
2217 if depthwise:
2218 # N*OH*OW*C*M
2219 dots = gtu.product(
2220 (ifm_shape[0], *outputs, *filter_shape[2:])
2221 )
2222 else:
2223 # N*OH*OW*OC or N*OD*OH*OW*OC
2224 dots = gtu.product(
2225 (ifm_shape[0], *outputs, filter_shape[0])
2226 )
2227 args_dict = {
2228 "acc_type": a,
2229 "stride": s,
2230 "pad": p,
2231 "dilation": d,
2232 "kernel": k_shape,
2233 "ks": k_size,
2234 "dot_products": dots,
2235 "shape": ifm_shape,
2236 }
2237
2238 # Support for larger values than 9 needs different delimiter
2239 delim = "" if max(s + p + d) <= 9 else "x"
2240 arg_list.append(
2241 (
2242 "acc{}_st{}_pad{}_dilat{}".format(
2243 testGen.typeStr(a),
2244 delim.join([str(x) for x in s]),
2245 delim.join([str(x) for x in p]),
2246 delim.join([str(x) for x in d]),
2247 ),
2248 args_dict,
2249 )
2250 )
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002251 if (
2252 error_name
2253 and len(arg_list) >= TosaArgGen.MAX_CONV_ERROR_IFS
2254 ):
2255 # Found enough errors
2256 logger.debug(
2257 f"Skipping creating more conv error tests for {error_name}"
2258 )
2259 more_tests = False
Tai Lyf36f2562024-03-14 16:21:29 +00002260 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002261
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002262 arg_list = TosaArgGen._add_data_generators(
2263 testGen,
2264 opName,
evacha019c96eef2024-02-07 11:21:55 +00002265 shapeList,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002266 dtypes[0],
2267 arg_list,
2268 error_name,
2269 )
2270 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002271 return arg_list
2272
2273 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002274 def agFullyConnected(testGen, rng, opName, shapeList, dtypes, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002275
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002276 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002277 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01002278
2279 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002280 accum_dtype = gtu.get_wrong_output_type(opName, rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01002281 elif error_name == ErrorIf.WrongInputType:
2282 # Pick some potentially correct output dtype if input type is incorrect
2283 accum_dtype = DType.INT32
2284 else:
Tai Lyf36f2562024-03-14 16:21:29 +00002285 accum_dtype = dtypes[-1] # use output dtype as accum_dtype
James Ward8b390432022-08-12 20:48:56 +01002286
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002287 # Set up compliance info
2288 args_dict = {
2289 "acc_type": accum_dtype,
2290 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
2291 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
2292 "shape": shapeList[0],
2293 }
2294
2295 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
2296
2297 arg_list = TosaArgGen._add_data_generators(
2298 testGen,
2299 opName,
evacha019c96eef2024-02-07 11:21:55 +00002300 shapeList,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002301 input_dtype,
2302 arg_list,
2303 error_name,
2304 )
2305 # Return list of tuples: (arg_str, args_dict)
2306 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002307
2308 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002309 def agMatMul(testGen, rng, opName, shapeList, dtype, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002310 # Get valid accumulate type(s)
2311 if dtype == DType.INT8:
2312 accum_dtypes = [DType.INT32]
2313 elif dtype == DType.INT16:
2314 accum_dtypes = [DType.INT48]
2315 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002316 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002317 elif dtype == DType.BF16:
2318 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002319 elif dtype == DType.FP32:
2320 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002321 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2322 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002323 elif error_name is None:
2324 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
2325
2326 if error_name == ErrorIf.WrongOutputType:
2327 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002328 accum_dtypes = [gtu.get_wrong_output_type(opName, rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01002329 elif error_name == ErrorIf.WrongInputType:
2330 # Pick some potentially correct output dtype if input type is incorrect
2331 accum_dtypes = [DType.INT32]
2332
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002333 # Set up compliance info
2334 args_dict = {
2335 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
2336 # Set dot_products = N*H*W
2337 "dot_products": gtu.product(
2338 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
2339 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002340 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002341 }
2342
2343 # Create arg tuple of string and dict
2344 arg_list = []
2345 for a in accum_dtypes:
2346 d = args_dict.copy()
2347 d["acc_type"] = a
2348 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01002349
2350 arg_list = TosaArgGen._add_data_generators(
2351 testGen,
2352 opName,
evacha019c96eef2024-02-07 11:21:55 +00002353 shapeList,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002354 dtype,
2355 arg_list,
2356 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002357 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002358 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01002359 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002360
2361 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002362 def agTransposeConv2D(testGen, rng, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002363 arg_list = []
2364
Jeremy Johnson0c716862023-04-13 17:18:19 +01002365 if testGen.args.level8k and error_name is not None:
2366 # Don't produce negative large tests
2367 return arg_list
2368
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002369 ifm_shape = shapeList[0]
2370 filter_shape = shapeList[1]
2371
Tai Lyf36f2562024-03-14 16:21:29 +00002372 accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
2373
2374 if error_name == ErrorIf.WrongAccumulatorType:
2375 accum_dtypes = (
2376 [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
2377 )
James Ward8b390432022-08-12 20:48:56 +01002378
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002379 # Must be rank 4
2380 if error_name != ErrorIf.WrongRank:
2381 assert len(ifm_shape) == 4
2382 assert len(filter_shape) == 4
2383
Jeremy Johnson0c716862023-04-13 17:18:19 +01002384 k_shape = tuple(filter_shape[1:3])
Jeremy Johnson95a67102024-01-10 14:16:39 +00002385 # compliance size - KS
2386 k_size = gtu.product((*k_shape, ifm_shape[3]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002387
Jeremy Johnson0c716862023-04-13 17:18:19 +01002388 if not testGen.args.level8k:
2389 # Generate comprehensive argument lists
2390 # - except for named errors, which use specific invalid value(s)
2391 smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
2392 if error_name == ErrorIf.PadLargerEqualKernel:
2393 max_filter_size = -max(k_shape[0], k_shape[1])
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002394 p_vals = [rng.choice(range(max_filter_size - 10, max_filter_size))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002395 else:
2396 p_vals = [
2397 x
2398 for x in range(
2399 smallest_padding_size, testGen.args.max_conv_padding + 1
2400 )
2401 ]
2402 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2403 if error_name == ErrorIf.StrideSmallerOne:
2404 # Can't use stride=0, as it is used as a divisor when deriving the output shape
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002405 s_vals = [rng.choice(range(-5, 0))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002406 else:
2407 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
2408 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002409
Jeremy Johnson0c716862023-04-13 17:18:19 +01002410 if not error_name and testGen.args.oversize:
2411 # add some oversize argument values
2412 if max(ifm_shape) < 64:
2413 bigPadding = 9
2414 paddings.update(
2415 {
2416 x
2417 for x in itertools.product(
2418 *([[smallest_padding_size, bigPadding]] * 4)
2419 )
2420 }
2421 )
2422 bigStride = 8
2423 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
2424
2425 # There are too many parameter combinations, so generate them sparsely,
2426 # very sparse for negative tests
2427 sparsity_factor = 2 if error_name else 10
2428 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
2429 # If there are only a small number of tests, just select them all
2430 if sparsity < 13:
2431 sparsity = 1
2432 # To get a variety of parameter combinations sparsity should not be a
2433 # multiple of 2, 3 or 5
2434 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
2435 sparsity += 1
2436 else:
2437 # Only test 8k levels boundaries
2438 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2439 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2440 bigPadding = bigKernel
2441
2442 pad_shape = [0] * (len(k_shape) * 2)
2443 stride_shape = [1] * len(k_shape)
2444 # The point at which input dimension combined with the stride will
2445 # create large output sizes!
2446 LARGE_SIZE = 2
2447 for idx in range(len(k_shape)):
2448 pad_offset = idx * 2
2449 if k_shape[idx] == bigKernel:
2450 # Set large stride
2451 stride_shape[idx] = bigKernel
2452 # Use negative output padding to reduce shape size
2453 pad_shape[pad_offset] = -(bigPadding - 1)
2454 if ifm_shape[idx + 1] > LARGE_SIZE:
2455 pad_shape[pad_offset + 1] = -(bigPadding - 1)
2456 else:
2457 # The other dimension should be the bigKernel
2458 alt_idx = 1 - idx
2459 if (
2460 k_shape[alt_idx] == bigKernel
2461 and ifm_shape[alt_idx + 1] < LARGE_SIZE
2462 ):
2463 # As the input is small, the large stride won't
2464 # affect the output so we can add some padding
2465 pad_shape[pad_offset + 1] = bigPadding
2466
2467 strides = {tuple(stride_shape)}
2468 paddings = {tuple(pad_shape)}
2469
2470 # Currently allow all combinations that are of a reasonable size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002471 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002472
2473 n = 0
Tai Lyf36f2562024-03-14 16:21:29 +00002474 for a in accum_dtypes:
2475 for s in sorted(list(strides)):
2476 for p in sorted(list(paddings)):
2477 if n % sparsity == 0:
2478 # Determine the output shape
2479 oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
2480 ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
2481 os = [ifm_shape[0], oh, ow, filter_shape[0]]
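# Worked example with hypothetical values: ifm H=14, stride 2, pads (1, 1),
# kernel 3 gives oh = (14 - 1) * 2 + 1 + 1 + 3 = 31; ow is derived the same way.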
Jeremy Johnson0c716862023-04-13 17:18:19 +01002482
Tai Lyf36f2562024-03-14 16:21:29 +00002483 # N*OH*OW*OC
2484 dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
2485 args_dict = {
2486 "acc_type": a,
2487 "stride": s,
2488 "pad": p,
2489 "kernel": k_shape,
2490 "ks": k_size,
2491 "dot_products": dots,
2492 "shape": ifm_shape,
2493 "out_shape": os,
2494 }
Jeremy Johnson95a67102024-01-10 14:16:39 +00002495
Tai Lyf36f2562024-03-14 16:21:29 +00002496 # Values larger than 9 need a different delimiter to keep the name unambiguous
2497 delim = "" if max(s + p) <= 9 else "x"
2498 arg_list.append(
2499 (
2500 "acc{}_st{}_pad{}_os{}".format(
2501 testGen.typeStr(a),
2502 delim.join([str(x) for x in s]),
2503 delim.join([str(x) for x in p]),
2504 "x".join([str(x) for x in os]),
2505 ),
2506 args_dict,
2507 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002508 )
Tai Lyf36f2562024-03-14 16:21:29 +00002509 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002510
Jeremy Johnson95a67102024-01-10 14:16:39 +00002511 arg_list = TosaArgGen._add_data_generators(
2512 testGen,
2513 opName,
evacha019c96eef2024-02-07 11:21:55 +00002514 shapeList,
Jeremy Johnson95a67102024-01-10 14:16:39 +00002515 dtypes[0],
2516 arg_list,
2517 error_name,
2518 )
2519 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002520 return arg_list
2521
2522 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002523 def agPad(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002524 rank = len(shapeList[0])
2525
2526 # Exhaustively test combinations of padding on each side of each dimension
2527 # - the range of padding values is defined by pad_min and pad_max
2528 # - for padding >9, the name format needs to be more distinctive
2529 pad_min, pad_max = 0, 1
2530 pad_values = [x for x in range(pad_min, pad_max + 1)]
2531 if error_name == ErrorIf.PadSmallerZero:
2532 pad_values = [x for x in range(-2, 0)]
2533 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
2534 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
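# Worked example with the default pad range [0, 1]: each axis has
# 2 * 2 = 4 (before, after) pairs, so a rank 3 shape yields 4**3 = 64
# candidate paddings.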
2535
2536 if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002537 pad_const_int = rng.randNumberDType(dtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002538 pad_const_fp = 0
Tai Ly60dc48c2024-03-08 22:19:41 +00002539 elif gtu.dtypeIsFloat(dtype):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002540 pad_const_int = 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002541 pad_const_fp = rng.randNumberDType(dtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002542 else:
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002543 assert error_name == ErrorIf.WrongInputType
2544 pad_const_int = 0
2545 pad_const_fp = 0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002546
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002547 list_shape_pad_values = list(shape_pad_values)
2548 # If we are producing tests for rank 6 or greater use sparsity
2549 if len(list_shape_pad_values) > 1024:
2550 sparsity_factor = 2 if error_name else 120
2551 sparsity = TosaArgGen._calculate_sparsity(
2552 len(list_shape_pad_values), sparsity_factor
2553 )
2554 else:
2555 sparsity = 1
2556
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002557 # Build arg list
2558 arg_list = []
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002559 for n, paddings in enumerate(list_shape_pad_values):
James Ward8b390432022-08-12 20:48:56 +01002560 paddings = list(paddings)
2561 args_valid = True
2562
2563 if error_name == ErrorIf.PadSmallerZero:
2564 # Prevent negative output shapes while ensuring still testing for negative padding
2565 for i in range(rank):
2566 dim_after_padding = (
2567 paddings[i][0] + paddings[i][1] + shapeList[0][i]
2568 )
2569 if dim_after_padding < 1:
2570 paddings[i] = (0, 0)
2571 if all([p > -1 for p in paddings[i]]):
2572 args_valid = False
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002573 if args_valid and n % sparsity == 0:
James Ward8b390432022-08-12 20:48:56 +01002574 name = "pad"
2575 for r in range(rank):
2576 before, after = paddings[r]
2577 name = f"{name}{before}{after}"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002578 args_dict = {
2579 "pad": np.array(paddings),
2580 "pad_const_int": pad_const_int,
2581 "pad_const_fp": pad_const_fp,
2582 }
2583 arg_list.append((name, args_dict))
James Ward8b390432022-08-12 20:48:56 +01002584
2585 if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002586 logger.debug(
2587 f"agPad: No PadSmallerZero ErrorIf test created for input shape: {shapeList[0]}"
2588 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002589
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002590 arg_list = TosaArgGen._add_data_generators(
2591 testGen,
2592 opName,
evacha019c96eef2024-02-07 11:21:55 +00002593 shapeList,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002594 dtype,
2595 arg_list,
2596 error_name,
2597 )
2598
2599 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002600 return arg_list
2601
2602 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002603 def agPooling(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002604 arg_list = []
2605
2606 shape = shapeList[0]
2607 if error_name != ErrorIf.WrongRank:
2608 assert len(shape) == 4
2609
Jeremy Johnson0c716862023-04-13 17:18:19 +01002610 test_level8k = testGen.args.level8k and error_name is None
2611
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002612 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002613 startKernel = 2
2614 startPad = 0
2615 if not test_level8k:
2616 # Generate comprehensive argument lists
2617 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2618 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2619 # Stride must be greater than 1 to force non-integer error
2620 s_vals = [
2621 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2622 ]
2623 strides = {x for x in itertools.product(*([s_vals] * 2))}
2624 k_vals = [
2625 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2626 ]
2627 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2628 max_dim_size = None
2629 else:
2630 # Only test 8k levels
2631 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2632 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2633 strides = {(1, bigStride), (bigStride, 4)}
2634 kernels = {(1, bigKernel), (bigKernel, 3)}
2635 paddings = set()
2636 for s in sorted(list(strides)):
2637 for k in sorted(list(kernels)):
2638 padding = []
2639 for idx in range(len(k)):
2640 total_padding = s[idx] - shape[idx + 1] + k[idx]
2641 while total_padding < 0:
2642 # Must meet: shape + padding > kernel
2643 total_padding += s[idx]
2644 if total_padding < k[idx]:
2645 padding.extend([0, total_padding])
2646 else:
2647 # Note this may produce padding >= k[idx] which is not
2648 # allowed - but will be ignored in the creation loop below
2649 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2650 paddings.add(tuple(padding))
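# Worked example with hypothetical values: s=4, input dim=9, k=3 gives
# total_padding = 4 - 9 + 3 = -2, bumped by the stride to 2; since 2 < 3
# the padding pair (0, 2) is used for that dimension.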
2651 # Create a limit for the output dimensions size
2652 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002653
James Ward8b390432022-08-12 20:48:56 +01002654 if opName == "max_pool2d":
2655 accum_dtypes = [None] # max_pool has no accumulate dtype
2656 elif dtype == DType.INT8 or dtype == DType.INT16:
2657 accum_dtypes = [DType.INT32]
2658 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002659 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002660 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002661 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002662 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2663 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002664 elif error_name is None:
2665 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2666 else:
2667 # Set to something for the ErrorIf case which has
2668 # incorrect input data-type
2669 accum_dtypes = [DType.INT32]
2670
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00002671 if error_name == ErrorIf.WrongAccumulatorType:
2672 accum_dtypes = list(gtu.usableDTypes(excludes=accum_dtypes))
2673
Jeremy Johnson0c716862023-04-13 17:18:19 +01002674 if not test_level8k:
2675 if testGen.args.oversize:
2676 # add some oversize argument values
2677 bigStride = 7
2678 bigKernel = 9
2679 strides.update(
2680 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002681 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002682 kernels.update(
2683 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2684 )
2685 if max(shape) < 64:
2686 # padding must be less than the kernel size
2687 bigPadding = bigKernel - 1
2688 paddings.update(
2689 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2690 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002691
Jeremy Johnson0c716862023-04-13 17:18:19 +01002692 # There are too many parameter combinations, so generate them sparsely,
2693 # very sparse for negative tests
2694 sparsity_factor = 2 if error_name else 500
2695 sparsity = (
2696 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2697 )
2698 else:
2699 # We have already limited test output combinations for 8k tests
2700 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002701
James Ward8b390432022-08-12 20:48:56 +01002702 arg_str = (
2703 "acc{}_st{}_kern{}_pad{}"
2704 if accum_dtypes[0] is not None
2705 else "st{}_kern{}_pad{}"
2706 )
2707
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002708 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01002709 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002710 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01002711
2712 # Values larger than 9 need a different delimiter to keep the name unambiguous
2713 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01002714 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01002715 delim.join([str(x) for x in stride]),
2716 delim.join([str(x) for x in kern]),
2717 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01002718 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002719 args_dict = {
2720 "stride": stride,
2721 "pad": pad,
2722 "kernel": kern,
2723 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002724 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002725 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
2726 }
James Ward8b390432022-08-12 20:48:56 +01002727
2728 if accum is not None:
2729 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002730 args_dict["acc_type"] = accum
2731 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002732
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002733 n = 0
James Ward8b390432022-08-12 20:48:56 +01002734 for a in accum_dtypes:
2735 for s in sorted(list(strides)):
2736 for p in sorted(list(paddings)):
2737 for k in sorted(list(kernels)):
2738 if error_name in [
2739 ErrorIf.StrideSmallerOne,
2740 ErrorIf.KernelSmallerOne,
2741 ErrorIf.PadSmallerZero,
2742 ErrorIf.PadLargerEqualKernel,
2743 ]:
2744 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002745 rng, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002746 )
James Ward8b390432022-08-12 20:48:56 +01002747 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002748 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002749 get_arg_list_element(a, sNew, pNew, kNew, shape=shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002750 )
James Ward8b390432022-08-12 20:48:56 +01002751 elif (
2752 n % sparsity == 0
2753 # padding must not exceed the kernel size
2754 and p[0] < k[0]
2755 and p[1] < k[0]
2756 and p[2] < k[1]
2757 and p[3] < k[1]
2758 # the padded shape must exceed the kernel size
2759 and (shape[1] + p[0] + p[1]) > k[0]
2760 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002761 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002762 partial_h = shape[1] + p[0] + p[1] - k[0]
2763 partial_w = shape[2] + p[2] + p[3] - k[1]
2764 remainder_h = partial_h % s[0]
2765 remainder_w = partial_w % s[1]
2766 output_h = partial_h // s[0] + 1
2767 output_w = partial_w // s[1] + 1
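# Worked example with hypothetical values: input H=32, pads (0, 1), k=3,
# stride 2 gives partial_h = 32 + 0 + 1 - 3 = 30, remainder_h = 0 and
# output_h = 30 // 2 + 1 = 16.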
Jeremy Johnsonaf090182024-02-13 18:25:39 +00002768 logger.debug(
2769 f"agPooling: {shape} remainder=({remainder_h}, {remainder_w}) output=({output_h}, {output_w})"
2770 )
James Ward8b390432022-08-12 20:48:56 +01002771 if (
2772 # the parameters must produce integer exact output
2773 error_name != ErrorIf.PoolingOutputShapeNonInteger
2774 and remainder_h == 0
2775 and remainder_w == 0
2776 ) or (
2777 error_name == ErrorIf.PoolingOutputShapeNonInteger
2778 and (remainder_h != 0 or remainder_w != 0)
2779 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002780 if (
2781 max_dim_size is not None
2782 and max(output_h, output_w) > max_dim_size
2783 ):
2784 # Test will consume too much memory - skip it
2785 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002786 # Dot products = N*OH*OW*C
2787 dp = gtu.product(
2788 (shape[0], output_h, output_w, shape[3])
2789 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002790 arg_list.append(
2791 get_arg_list_element(a, s, p, k, dp, shape)
2792 )
James Ward8b390432022-08-12 20:48:56 +01002793 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002794
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002795 # Now add data generator types
2796 arg_list = TosaArgGen._add_data_generators(
2797 testGen,
2798 opName,
evacha019c96eef2024-02-07 11:21:55 +00002799 shapeList,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002800 dtype,
2801 arg_list,
2802 error_name,
2803 )
2804
2805 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002806 return arg_list
2807
2808 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002809 def agCast(testGen, rng, opName, shapeList, inDtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002810 arg_list = []
2811
2812 # Enumerate the output types here
2813 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002814 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(inDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002815 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00002816 dtypeList = [
2817 DType.BOOL,
2818 DType.INT16,
2819 DType.INT32,
2820 DType.FP16,
2821 DType.BF16,
2822 DType.FP32,
2823 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002824 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00002825 dtypeList = [
2826 DType.BOOL,
2827 DType.INT8,
2828 DType.INT32,
2829 DType.FP16,
2830 DType.BF16,
2831 DType.FP32,
2832 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002833 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00002834 dtypeList = [
2835 DType.BOOL,
2836 DType.INT8,
2837 DType.INT16,
2838 DType.FP16,
2839 DType.BF16,
2840 DType.FP32,
2841 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002842 elif inDtype == DType.BOOL:
2843 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01002844 elif inDtype == DType.FP16:
Won Jeon2c34b462024-02-06 18:37:00 +00002845 dtypeList = [
2846 DType.INT8,
2847 DType.INT16,
2848 DType.INT32,
2849 DType.FP32,
2850 DType.FP8E4M3,
2851 DType.FP8E5M2,
2852 ]
James Ward24dbc422022-10-19 12:20:31 +01002853 elif inDtype == DType.BF16:
Won Jeon2c34b462024-02-06 18:37:00 +00002854 dtypeList = [
2855 DType.INT8,
2856 DType.INT16,
2857 DType.INT32,
2858 DType.FP32,
2859 DType.FP8E4M3,
2860 DType.FP8E5M2,
2861 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002862 elif inDtype == DType.FP32:
Won Jeon2c34b462024-02-06 18:37:00 +00002863 dtypeList = [
2864 DType.INT8,
2865 DType.INT16,
2866 DType.INT32,
2867 DType.FP16,
2868 DType.BF16,
2869 DType.FP8E4M3,
2870 DType.FP8E5M2,
2871 ]
2872 elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]:
2873 dtypeList = [DType.FP16, DType.BF16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002874 elif error_name == ErrorIf.WrongInputType:
2875 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002876 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002877 else:
2878 raise Exception("Unexpected input dtype: {}".format(inDtype))
2879
2880 for dtype in dtypeList:
Jeremy Johnson708da822023-11-15 16:25:45 +00002881 arg_list.append(
2882 ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
2883 )
2884
2885 # Now add data generator types
2886 arg_list = TosaArgGen._add_data_generators(
2887 testGen,
2888 opName,
evacha019c96eef2024-02-07 11:21:55 +00002889 shapeList,
Jeremy Johnson708da822023-11-15 16:25:45 +00002890 dtype,
2891 arg_list,
2892 error_name,
2893 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002894
2895 return arg_list
2896
2897 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002898 def agRescale(testGen, rng, opName, shapeList, inDtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002899 arg_list = []
2900
2901 # Enumerate the output types here
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002902 for outDtype in [
2903 DType.UINT8,
2904 DType.INT8,
2905 DType.INT16,
2906 DType.INT32,
2907 DType.UINT16,
2908 ]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002909 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002910 outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002911 and error_name == ErrorIf.OutputZeroPointNotZero
2912 ):
2913 continue
2914 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002915 outDtype != DType.UINT16
2916 and error_name == ErrorIf.U16OutputZeroPointNotValid
2917 ) or (
2918 inDtype != DType.UINT16
2919 and error_name == ErrorIf.U16InputZeroPointNotValid
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002920 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002921 # ErrorIfs only valid with UINT16
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002922 continue
2923 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002924 inDtype == DType.UINT8
2925 and outDtype not in [DType.INT8, DType.INT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002926 and error_name != ErrorIf.WrongOutputType
2927 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002928 # The only output dtypes for UINT8 are INT8/INT16, skip all others
2929 continue
2930 if (
2931 inDtype not in [DType.INT8, DType.INT16]
2932 and outDtype == DType.UINT8
2933 and error_name != ErrorIf.WrongOutputType
2934 ):
2935 # The only input dtypes for UINT8 are INT8/INT16, skip all others
2936 continue
2937 if (
2938 inDtype == DType.UINT16
2939 and outDtype != DType.INT16
2940 and error_name != ErrorIf.WrongOutputType
2941 ):
2942 # The only output dtype for UINT16 is INT16, skip all others
2943 continue
2944 if (
2945 inDtype != DType.INT16
2946 and outDtype == DType.UINT16
2947 and error_name != ErrorIf.WrongOutputType
2948 ):
2949 # The only input dtype for UINT16 is INT16, skip all others
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002950 continue
2951 if (
2952 error_name == ErrorIf.WrongOutputType
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002953 and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002954 ):
2955 continue
2956
2957 for scale32 in [False, True]:
2958 if error_name == ErrorIf.ScaleTrue and not scale32:
2959 continue
2960 elif error_name == ErrorIf.ScaleNotTrue and scale32:
2961 continue
2962 for double_round in [False, True]:
2963 if error_name == ErrorIf.ScaleNotTrue and not double_round:
2964 continue
2965 for per_channel in [False, True]:
2966
2967 if (
2968 inDtype == DType.INT48
2969 and scale32
2970 and error_name != ErrorIf.ScaleTrue
2971 ):
2972 # Illegal condition. Must be scale32=False
2973 continue
2974 if (
2975 double_round
2976 and not scale32
2977 and error_name != ErrorIf.ScaleNotTrue
2978 ):
2979 # Illegal condition. ERROR_IF(!scale32 && double_round)
2980 continue
2981
Tai Ly6e1e2bc2024-03-01 20:59:32 +00002982 if per_channel:
2983 nc = shapeList[0][-1]
2984 else:
2985 nc = 1
2986
2987 in_type_width = gtu.dtypeWidth(inDtype)
2988 out_type_width = gtu.dtypeWidth(outDtype)
2989
2990 # Calculate scale based on:
2991 # scale = a * (2^output_width) / (2^input_width)
2992
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002993 a = np.float32(rng.random(size=[nc]))
Tai Ly6e1e2bc2024-03-01 20:59:32 +00002994 scale_arr = a * np.float32(
2995 (1 << out_type_width) / (1 << in_type_width)
2996 )
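# Worked example with hypothetical values: INT8 -> INT16 gives widths
# 8 -> 16, so scale_arr = a * (1 << 16) / (1 << 8) = a * 256; with
# a = 0.5 the scale would be 128.0 before clipping.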
2997
2998 if scale32:
2999 # Cap the scaling at 2^31 - 1 for scale32
3000 scale_arr = np.clip(
3001 scale_arr, 1.0 / (1 << 31), (1 << 31) - 1
3002 )
3003 else:
3004 # Cap the scaling at 2^15 - 1 for scale16
3005 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
3006
Jeremy Johnsonaf090182024-02-13 18:25:39 +00003007 logger.debug(
3008 f"agRescale: {out_type_width} {in_type_width} -> {scale_arr}"
3009 )
Tai Ly6e1e2bc2024-03-01 20:59:32 +00003010
3011 multiplier_arr = np.int32(np.zeros(shape=[nc]))
3012 shift_arr = np.int32(np.zeros(shape=[nc]))
3013 for i in range(nc):
3014 (
3015 multiplier_arr[i],
3016 shift_arr[i],
3017 ) = TosaQuantGen.computeMultiplierAndShift(
3018 scale_arr[i], scale32
3019 )
3020
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003021 arg_list.append(
3022 (
3023 "out{}_sc{}_dr{}_pc{}".format(
Jeremy Johnson3b0544c2022-10-18 16:32:19 +01003024 testGen.typeStr(outDtype),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003025 int(scale32),
3026 int(double_round),
3027 int(per_channel),
3028 ),
Jeremy Johnson587cc842024-02-08 11:45:44 +00003029 {
3030 "output_dtype": outDtype,
3031 "scale": scale32,
3032 "double_round": double_round,
3033 "per_channel": per_channel,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00003034 "multiplier": multiplier_arr,
3035 "shift": shift_arr,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003036 },
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003037 )
3038 )
3039
Jeremy Johnson587cc842024-02-08 11:45:44 +00003040 arg_list = TosaArgGen._add_data_generators(
3041 testGen,
3042 opName,
evacha019c96eef2024-02-07 11:21:55 +00003043 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003044 inDtype,
3045 arg_list,
3046 error_name,
3047 )
3048 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003049 return arg_list
3050
3051 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003052 def agMul(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003053 arg_list = []
3054
3055 if dtype is DType.INT32:
3056 for p in range(testGen.args.num_rand_permutations):
3057
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003058 shift = rng.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003059 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003060 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003061 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003062
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003063 arg_list = TosaArgGen._add_data_generators(
3064 testGen,
3065 opName,
evacha019c96eef2024-02-07 11:21:55 +00003066 shapeList,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003067 dtype,
3068 arg_list,
3069 error_name,
3070 )
3071 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003072 return arg_list
3073
3074 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003075 def agArithmeticRightShift(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003076 arg_list = []
3077
Jeremy Johnson587cc842024-02-08 11:45:44 +00003078 for round in (True, False):
3079 args_dict = {
3080 "round": round,
3081 }
3082 arg_list.append((f"round{round}", args_dict))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003083
Jeremy Johnson587cc842024-02-08 11:45:44 +00003084 arg_list = TosaArgGen._add_data_generators(
3085 testGen,
3086 opName,
evacha019c96eef2024-02-07 11:21:55 +00003087 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003088 dtype,
3089 arg_list,
3090 error_name,
3091 )
3092 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003093 return arg_list
3094
Luke Hutton57287132023-02-06 14:54:18 +00003095 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003096 def agFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Luke Hutton57287132023-02-06 14:54:18 +00003097 arg_list = []
3098
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003099 shape = shapeList[0]
3100 dot_products = gtu.product(shape)
3101 ks = 2 * shape[1] * shape[2] # 2*H*W
3102 for inverse in (True, False):
3103 args_dict = {
3104 "dot_products": dot_products,
3105 "shape": shape,
3106 "ks": ks,
3107 "acc_type": dtype,
3108 "inverse": inverse,
3109 }
3110 arg_list.append((f"inverse{inverse}", args_dict))
Luke Hutton57287132023-02-06 14:54:18 +00003111
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003112 arg_list = TosaArgGen._add_data_generators(
3113 testGen,
3114 opName,
evacha019c96eef2024-02-07 11:21:55 +00003115 shapeList,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003116 dtype,
3117 arg_list,
3118 error_name,
3119 )
3120 # Return list of tuples: (arg_str, args_dict)
Luke Hutton57287132023-02-06 14:54:18 +00003121 return arg_list
3122
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003123 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003124 def agRFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003125 arg_list = []
3126
3127 shape = shapeList[0]
3128 dot_products = gtu.product(shape)
3129 ks = shape[1] * shape[2] # H*W
3130 args_dict = {
3131 "dot_products": dot_products,
3132 "shape": shape,
3133 "ks": ks,
3134 "acc_type": dtype,
3135 }
3136 arg_list.append(("", args_dict))
3137
3138 arg_list = TosaArgGen._add_data_generators(
3139 testGen,
3140 opName,
evacha019c96eef2024-02-07 11:21:55 +00003141 shapeList,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003142 dtype,
3143 arg_list,
3144 error_name,
3145 )
3146 # Return list of tuples: (arg_str, args_dict)
3147 return arg_list
3148
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003149 # Helper function for reshape. Gets some factors of a larger number.
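# Note: only factors up to sqrt(val) are returned, e.g. getFactors(24) == [1, 2, 3, 4].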
3150 @staticmethod
3151 def getFactors(val, start=1):
3152 factors = []
3153
3154 for i in range(start, int(np.sqrt(val)) + 1):
3155 if (val % i) == 0:
3156 factors.append(i)
3157
3158 return factors
3159
3160 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003161 def agReshape(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003162 arg_list = []
3163
3164 origShape = shapeList[0]
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003165 totalElements = gtu.product(origShape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003166 factors = TosaArgGen.getFactors(totalElements)
3167
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003168 # Find new shapes up to the number of permutations asked for
3169 # This code is NOT fast. Fortunately, the numbers are fairly small.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003170 for p in range(testGen.args.num_rand_permutations):
Jeremy Johnsondd975b82024-02-28 17:29:13 +00003171 # Rank from 1 to MAX_TENSOR_RANK
3172 newRank = rng.randInt(1, (gtu.MAX_TENSOR_RANK + 1))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003173 if len(factors) < newRank:
3174 continue
3175
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003176 # escape_counter limits the generation of new shapes to a reasonable time
3177 for escape_counter in range(100):
3178
3179 # Generate the new shape of the chosen new rank
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003180 newShape = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003181 remainingElements = totalElements
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003182 shuffledFactors = rng.permutation(factors)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003183 for i in range(1, newRank):
3184 # pick rank-1 factors
3185 newShape.append(shuffledFactors[0])
3186 remainingElements = remainingElements // shuffledFactors[0]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003187 shuffledFactors = rng.permutation(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003188 TosaArgGen.getFactors(remainingElements)
3189 )
3190 newShape.append(remainingElements)
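# Worked example with hypothetical values: 24 elements and newRank 3 might
# pick factor 2 (12 remaining), then factor 3 (4 remaining), giving
# newShape == [2, 3, 4].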
3191
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003192 # Check for duplicates
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003193 duplicate = False
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003194 for name, args_dict in arg_list:
3195 if args_dict["new_shape"] == newShape:
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003196 duplicate = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003197 break
3198
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003199 if not duplicate:
3200 outShape = "x".join([str(x) for x in newShape])
3201 arg_list.append(
3202 (
3203 "perm{}_rank{}_out{}".format(p, newRank, outShape),
3204 {"new_shape": newShape},
3205 )
3206 )
3207 # Found an output shape for this permutation
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003208 break
3209
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003210 # Now add data generator types
3211 arg_list = TosaArgGen._add_data_generators(
3212 testGen,
3213 opName,
evacha019c96eef2024-02-07 11:21:55 +00003214 shapeList,
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003215 dtype,
3216 arg_list,
3217 error_name,
3218 )
3219
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003220 return arg_list
3221
3222 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003223 def agTranspose(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003224 arg_list = []
3225
3226 ifm_shape = shapeList[0]
3227
3228 if error_name == ErrorIf.IndexOutsideBounds:
3229 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
3230 incorrect_small_index = range(-len(ifm_shape), 0)
3231 permutations = [p for p in itertools.permutations(incorrect_large_index)]
3232 permutations.extend(
3233 [p for p in itertools.permutations(incorrect_small_index)]
3234 )
3235 elif error_name == ErrorIf.IndexUsedTwice:
3236 # Create list with a duplicated index
3237 perm_range = list(range(len(ifm_shape)))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003238 index_choice = rng.choice(range(len(perm_range)))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003239 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
3240 permutations = [p for p in itertools.permutations(perm_range)]
3241
3242 else:
3243 # Get all permutations
3244 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
3245
3246 # Limit to possible permutations from shape dimension or argument setting
3247 limit = min(len(permutations), testGen.args.num_rand_permutations)
3248
3249 # Get random permutation generator that uses all permutations
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003250 random_permutations = rng.permutation(permutations)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003251
3252 # Create list of required amount of permutations
3253 arg_list = [
evacha0198477222024-01-26 12:25:32 +00003254 ("perm{}".format(p), {"perms": random_permutations[p].tolist()})
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003255 for p in range(limit)
3256 ]
evacha0198477222024-01-26 12:25:32 +00003257 # Now add data generator types
3258 arg_list = TosaArgGen._add_data_generators(
3259 testGen,
3260 opName,
evacha019c96eef2024-02-07 11:21:55 +00003261 shapeList,
evacha0198477222024-01-26 12:25:32 +00003262 dtype,
3263 arg_list,
3264 error_name,
3265 )
3266 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003267 return arg_list
3268
3269 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003270 def agSlice(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003271 arg_list = []
3272
3273 ifm_shape = shapeList[0]
3274 rank = len(ifm_shape)
3275
3276 for p in range(testGen.args.num_rand_permutations):
3277 start = []
3278 size = []
3279
3280 valid = True
3281
3282 for i in range(rank):
3283 if ifm_shape[i] > 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003284 start.append(rng.randInt(0, ifm_shape[i]))
3285 size.append(rng.randInt(0, ifm_shape[i] - start[i]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003286
3287 # Invalid slice size?
3288 if size[i] == 0:
3289 valid = False
3290 else:
3291 start.append(0)
3292 size.append(1)
3293
3294 if valid:
3295 # If ERROR_IF test required then incorrect start, size will be returned
3296 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003297 rng, error_name, ifm_shape, start, size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003298 )
evacha017f7d4252024-01-24 12:08:09 +00003299 arg_list.append(("perm{}".format(p), {"start": start, "size": size}))
3300 # Now add data generator types
3301 arg_list = TosaArgGen._add_data_generators(
3302 testGen,
3303 opName,
evacha019c96eef2024-02-07 11:21:55 +00003304 shapeList,
evacha017f7d4252024-01-24 12:08:09 +00003305 dtype,
3306 arg_list,
3307 error_name,
3308 )
3309 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003310 return arg_list
3311
3312 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003313 def agTile(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003314 arg_list = []
3315
3316 ifm_shape = shapeList[0]
3317 rank = len(ifm_shape)
3318
3319 for p in range(testGen.args.num_rand_permutations):
3320
3321 # Pick a few random but small multiple values
3322 # because otherwise this has a tendency to generate
3323 # enormous tensors
3324 multiples = []
3325 for i in range(rank):
3326 if ifm_shape[i] > 1000:
3327 # Multiple of 1 if ifm_shape dimension is large to reduce
3328 # tensor size
3329 multiples.append(1)
3330 elif max(ifm_shape) > 1000:
3331 multiples.append(2)
3332 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003333 multiples.append(rng.randInt(1, 4))
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003334 arg_list.append(("perm{}".format(p), {"multiples": multiples}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003335
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003336 # Now add data generator types
3337 arg_list = TosaArgGen._add_data_generators(
3338 testGen,
3339 opName,
evacha019c96eef2024-02-07 11:21:55 +00003340 shapeList,
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003341 dtype,
3342 arg_list,
3343 error_name,
3344 )
3345 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003346 return arg_list
3347
3348 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003349 def agResize(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003350 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003351 ifm_shape = shapeList[0]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003352
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003353 def get_aspect_ratio_resize_params():
3354 common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003355 aspect_ratio = rng.choice(common_aspect_ratios)
3356 invert = rng.choice((False, True))
3357 letterbox = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003358
3359 scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
3360 scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
3361 scale_y_d = scale_x_d = 1
3362 offset_x = offset_y = 0
3363
3364 if letterbox:
3365 max_border = scale_y_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003366 border_y = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003367 border_x = 0
3368 else:
3369 # Pillarboxing
3370 border_y = 0
3371 max_border = scale_x_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003372 border_x = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003373
3374 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3375 offset = (offset_y, offset_x)
3376 border = (border_y, border_x)
3377
3378 return scale, offset, border
3379
3380 def get_upscale_downscale_params():
3381 valid_params = False
3382 while not valid_params:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003383 upscale = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003384
3385 # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003386 origin_sampling = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003387
3388 if upscale:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003389 shift = rng.randInt(low=1, high=4)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003390 scale_x_d = scale_y_d = 1
3391 scale_x_n = scale_y_n = (
3392 1 << shift if origin_sampling else 2 << shift
3393 )
3394 border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
3395 offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
3396 else:
3397 scale_x_n = 1
3398 scale_y_n = 1
3399
3400 # Return list of valid scale_*_d values (max value 4) given input dim shape
3401 def get_valid_denom(ifm_dim):
3402 return [x for x in range(1, 5) if ifm_dim % x == 1]
3403
3404 # Generate list of valid downscale values and choose one randomly
3405 valid_scale_y_ds = get_valid_denom(ifm_shape[1])
3406 valid_scale_x_ds = get_valid_denom(ifm_shape[2])
3407
3408 if not valid_scale_y_ds and not valid_scale_x_ds:
3409 # Bad parameters, skip
3410 continue
3411
3412 if not valid_scale_y_ds:
3413 scale_y_d = 1
3414 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003415 scale_y_d = rng.choice(valid_scale_y_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003416
3417 if not valid_scale_x_ds:
3418 scale_x_d = 1
3419 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003420 scale_x_d = rng.choice(valid_scale_x_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003421
3422 border_x = border_y = 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003423 offset_y = rng.randInt(0, 16 * scale_y_n)
3424 offset_x = rng.randInt(0, 16 * scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003425 valid_params = True
3426
3427 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3428 offset = (offset_y, offset_x)
3429 border = (border_y, border_x)
3430 return scale, offset, border
3431
3432 def get_rand_params():
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003433 def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
3434 scale = scale_n / scale_d
3435 if scale > max_scale:
3436 factor = scale / max_scale
3437 new_scale_d = math.ceil(scale_d * factor)
3438 assert scale_n / new_scale_d <= max_scale
3439 scale_d = new_scale_d
3440 return scale_d
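# Worked example with hypothetical values: scale 200/2 = 100 with
# max_scale 64 gives factor = 1.5625, new_scale_d = ceil(2 * 1.5625) = 4,
# so the scale becomes 200/4 = 50 which is within the limit.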
3441
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003442 # Scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003443 scale_y_n = rng.randInt(low=1, high=(1 << 11))
3444 scale_x_n = rng.randInt(low=1, high=(1 << 11))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003445
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003446 scale_y_d = rng.randInt(low=1, high=(16 * scale_y_n))
3447 scale_x_d = rng.randInt(low=1, high=(16 * scale_x_n))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003448
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003449 scale_y_d = fix_scale_to_max_scale(
3450 scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3451 )
3452 scale_x_d = fix_scale_to_max_scale(
3453 scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3454 )
3455
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003456 # Offsets and border within the scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003457 offset_y = rng.randInt(low=-scale_y_n, high=(16 * scale_y_n))
3458 offset_x = rng.randInt(low=-scale_x_n, high=(16 * scale_x_n))
3459 border_y = rng.randInt(low=(-16 * scale_y_n), high=scale_y_n)
3460 border_x = rng.randInt(low=(-16 * scale_x_n), high=scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003461
3462 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3463 offset = (offset_y, offset_x)
3464 border = (border_y, border_x)
3465 return scale, offset, border
3466
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003467 def get_level_8k_params():
3468 # Create 64x scale - 64/1 to 2048/32
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003469 scale_d = rng.randInt(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003470 low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
3471 )
3472 scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
3473 # Create half to fifth scaling
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003474 scale_d_alt = rng.randInt(low=2, high=6)
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003475 scale_n_alt = 1
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003476 switch = rng.choice((False, True))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003477 if switch:
3478 scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
3479 else:
3480 scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)
3481
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003482 offset_y = rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
3483 offset_x = rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003484 offset = (offset_y, offset_x)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003485 border_y = rng.choice((-16 * scale[0], 0, scale[0] - 1))
3486 border_x = rng.choice((-16 * scale[2], 0, scale[2] - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003487 border = (border_y, border_x)
3488 return scale, offset, border
3489
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003490 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003491 # Exclude illegal {mode, type} configurations. Pick legal output types
3492 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
3493 outputDTypeList = [DType.INT8]
3494 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
3495 outputDTypeList = [DType.INT16]
3496 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
3497 outputDTypeList = [DType.INT32]
3498 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
3499 outputDTypeList = [DType.INT48]
James Ward8b390432022-08-12 20:48:56 +01003500 elif dtype == DType.FP16:
3501 outputDTypeList = [DType.FP16]
James Ward24dbc422022-10-19 12:20:31 +01003502 elif dtype == DType.BF16:
3503 outputDTypeList = [DType.BF16]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003504 elif dtype == DType.FP32:
3505 outputDTypeList = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00003506 elif dtype == DType.FP8E4M3:
3507 outputDTypeList = [DType.FP8E4M3]
3508 elif dtype == DType.FP8E5M2:
3509 outputDTypeList = [DType.FP8E5M2]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003510 elif error_name == ErrorIf.WrongInputType:
3511 # If an incorrect input type is used then we set a 'correct'
3512 # output type to avoid other errors
3513 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
3514 else:
3515 continue
3516
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003517 arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"
3518
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003519 for outputDType in outputDTypeList:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003520 perm = 0
3521 while perm < testGen.args.num_rand_permutations:
3522 # Random choice of type of params we are testing
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003523 if not testGen.args.level8k:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003524 _rnd_param_fn = rng.choice(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003525 (
3526 get_rand_params,
3527 get_upscale_downscale_params,
3528 get_aspect_ratio_resize_params,
3529 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003530 )
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003531 scale, offset, border = _rnd_param_fn()
3532 else:
3533 scale, offset, border = get_level_8k_params()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003534
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003535 # Expand params for bounds-checking
3536 (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
3537 (offset_y, offset_x) = offset
3538 (border_y, border_x) = border
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003539
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003540 # Make sure output dimensions OH and OW are integers
3541 partial_output_y = (
3542 (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
3543 )
3544 partial_output_x = (
3545 (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
3546 )
3547 if error_name == ErrorIf.ResizeOutputShapeNonInteger:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003548 # Look for non-integer test
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003549 if (
3550 partial_output_y % scale_y_d == 0
3551 and partial_output_x % scale_x_d == 0
3552 ):
3553 # Skip this test as it doesn't produce NonInteger output
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003554 if perm > 0:
3555 perm += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003556 continue
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003557 else:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003558 # Alter the scaling factors to make the output integer
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003559 while partial_output_y % scale_y_d != 0:
3560 scale_y_d -= 1
3561 while partial_output_x % scale_x_d != 0:
3562 scale_x_d -= 1
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003563 # Make sure we are still within max scaling
3564 if (
3565 scale_y_n / scale_y_d
3566 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
3567 scale_x_n / scale_x_d
3568 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
3569 # Skip the test as it is using too large a scaling factor
3570 if perm > 0:
3571 perm += 1
3572 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003573
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003574 output_y = partial_output_y // scale_y_d + 1
3575 output_x = partial_output_x // scale_x_d + 1
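# Worked example with hypothetical values: input H=16, scale 2/1,
# offset 0, border 0 gives partial_output_y = (16 - 1) * 2 = 30 and
# output_y = 30 // 1 + 1 = 31.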
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003576
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003577 if (
3578 output_y >= testGen.args.max_resize_output_dim
3579 or output_x >= testGen.args.max_resize_output_dim
3580 ) and error_name is None:
3581 # Skip positive test if output dim will be too high
3582 # Avoid high test latency and OOM issues
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003583 if not testGen.args.level8k or perm > 0:
3584 perm += 1
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003585 continue
3586
3587 if (
3588 output_y <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003589 or output_y >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003590 or output_x <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003591 or output_x >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003592 ):
3593 # Output dimensions out of scope
3594 if error_name is not None and perm > 0:
3595 # As long as we have one ERROR_IF test, don't worry
3596 # about creating all the other permutations
3597 perm += 1
3598 continue
3599
3600 if error_name == ErrorIf.ResizeOutputShapeMismatch and (
3601 (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003602 output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003603 and output_y - scale_y_d < 1
3604 )
3605 or (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003606 output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003607 and output_x - scale_x_d < 1
3608 )
3609 ):
3610 # Can't create a negative test with these params as it
3611 # will create invalid output size
3612 if perm > 0:
3613 perm += 1
3614 continue
3615
3616 scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
3617 offset = [offset_y, offset_x]
3618 border = [border_y, border_x]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003619
3620 # Common for all data types
3621 if error_name is not None:
3622 (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003623 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003624 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003625 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003626 outputDTypeNew,
3627 ) = TosaErrorIfArgGen.eiResizeErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003628 rng,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003629 error_name,
3630 mode,
3631 dtype,
3632 shapeList,
3633 outputDType,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003634 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003635 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003636 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003637 )
3638 else:
3639 outputDTypeNew = outputDType
3640
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003641 arg_to_append = (
3642 arg_str.format(
3643 "N" if mode == ResizeMode.NEAREST else "B",
3644 testGen.typeStr(outputDTypeNew),
3645 scale[0],
3646 scale[1],
3647 scale[2],
3648 scale[3],
3649 offset[0],
3650 offset[1],
3651 border[0],
3652 border[1],
3653 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003654 {
3655 "mode": mode,
3656 "scale": scale,
3657 "offset": offset,
3658 "border": border,
3659 "output_dtype": outputDTypeNew,
3660 },
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003661 )
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003662 if arg_to_append in arg_list:
3663 # Skip already generated test params
3664 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003665
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003666 # Valid permutation
3667 perm += 1
3668 arg_list.append(arg_to_append)
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003669
3670 # Now add data generator types
3671 arg_list = TosaArgGen._add_data_generators(
3672 testGen,
3673 opName,
evacha019c96eef2024-02-07 11:21:55 +00003674 shapeList,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003675 dtype,
3676 arg_list,
3677 error_name,
3678 )
3679 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003680 return arg_list
3681
3682 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003683 def agTable(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003684 arg_list = []
3685
3686 if dtype == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003687 table = np.int32(rng.integers(low=-128, high=128, size=[256])).tolist()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003688 else: # INT16
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003689 table = np.int32(rng.integers(low=-32768, high=32768, size=[513])).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003690 # Make sure all slopes are within the REQUIREd min/max of a 16-bit int
3691 for idx in range(len(table) - 1):
3692 slope = table[idx + 1] - table[idx]
3693 # Alter the next table entry to force the slope to be ok
3694 if slope > 32767:
3695 table[idx + 1] -= slope - 32767
3696 if slope < -32768:
3697 table[idx + 1] -= slope + 32768
3698 slope = table[idx + 1] - table[idx]
3699 assert slope <= 32767 and slope >= -32768
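# Worked example with hypothetical values: entries -30000 then 10000 give a
# slope of 40000, so the second entry is reduced by 40000 - 32767 = 7233 to
# 2767, leaving a slope of exactly 32767.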
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003700 arg_list.append(
3701 (
3702 "",
Jeremy Johnson587cc842024-02-08 11:45:44 +00003703 {"table": table},
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003704 )
3705 )
Jeremy Johnson587cc842024-02-08 11:45:44 +00003706 # Now add data generator types
3707 arg_list = TosaArgGen._add_data_generators(
3708 testGen,
3709 opName,
evacha019c96eef2024-02-07 11:21:55 +00003710 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003711 dtype,
3712 arg_list,
3713 error_name,
3714 )
3715 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003716 return arg_list
3717
    @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003718 def agCondIf(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003719 # CondIf generates the condition values here.
3720 # Convert to tensors in the build function, along with the
3721 # then and else blocks
3722 arg_list = []
3723
3724 for c in [False, True]:
Jeremy Johnson587cc842024-02-08 11:45:44 +00003725 arg_list.append(("cond{}".format(int(c)), {"condition": c}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003726
Jeremy Johnson587cc842024-02-08 11:45:44 +00003727 # Now add data generator types
3728 arg_list = TosaArgGen._add_data_generators(
3729 testGen,
3730 opName,
evacha019c96eef2024-02-07 11:21:55 +00003731 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003732 dtype,
3733 arg_list,
3734 error_name,
3735 )
3736 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003737 return arg_list
3738
    @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003739 def agWhileLoop(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003740 # While loop: 0 iterations, 1, more than 1
3741 arg_list = []
3742
Jeremy Johnson587cc842024-02-08 11:45:44 +00003743 for iterations in [0, 1, 4]:
3744 arg_list.append(("iter{}".format(iterations), {"iterations": iterations}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003745
Jeremy Johnson587cc842024-02-08 11:45:44 +00003746 # Now add data generator types
3747 arg_list = TosaArgGen._add_data_generators(
3748 testGen,
3749 opName,
evacha019c96eef2024-02-07 11:21:55 +00003750 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003751 dtype,
3752 arg_list,
3753 error_name,
3754 )
3755 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003756 return arg_list