blob: f9499b5f820eee6dbe785282c8bf7e9ec86ee0ae [file] [log] [blame]
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
Jeremy Johnsonaf090182024-02-13 18:25:39 +00004import logging
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01005import math
6
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
16# DTypeNames, DType, Op and ResizeMode are convenience variables to the
17# flatc-generated types that should be enums, but aren't
18
Jeremy Johnsonaf090182024-02-13 18:25:39 +000019logging.basicConfig()
20logger = logging.getLogger("tosa_verif_build_tests")
21
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010022
class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        # Stateless - every helper is a static method.
        pass

    @staticmethod
    def getZeroPoint(rng, zeropoint, dtype, error_name=None):
        """Return a zero point for the given dtype.

        A caller-supplied zeropoint is clamped to the type's legal range,
        otherwise a random value is picked for INT8/UINT8.  For all other
        types the zero point is 0, unless a ZeroPointNotZero error test
        explicitly asks for a non-zero value.
        """
        if dtype == DType.INT8:
            if zeropoint is not None:
                return min(127, max(-128, zeropoint))
            return rng.randInt(-128, 128)

        if dtype == DType.UINT8:
            if zeropoint is not None:
                return min(255, max(0, zeropoint))
            return rng.randInt(0, 256)

        if error_name in (
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ):
            # Deliberately invalid: any non-zero value will do
            zp = rng.randInt(-128, 128)
            return zp if zp != 0 else 1

        return 0

    @staticmethod
    def qgUnary(rng, zeropoint, op, dtype, error_name=None):
        """Return [input_zp, output_zp] for a unary operator."""
        # Only inject the requested error into the tensor slot it names
        in_err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        out_err = error_name if error_name == ErrorIf.OutputZeroPointNotZero else None
        return [
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, in_err),
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, out_err),
        ]

    @staticmethod
    def qgConv(rng, zeropoint, op, dtype_or_dtypeList, error_name=None):
        """Return [input_zp, weight_zp] for a convolution-style operator."""
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        # Only inject the requested error into the tensor slot it names
        in_err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        wgt_err = error_name if error_name == ErrorIf.WeightZeroPointNotZero else None
        return [
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0], in_err),
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1], wgt_err),
        ]

    @staticmethod
    def qgMatmul(rng, zeropoint, op, dtype, error_name=None):
        """Return [a_zp, b_zp] for MATMUL; both carry the error when requested."""
        err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        return [
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, err),
            TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, err),
        ]

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into a fixed-point multiplier and shift.

        Derived from computeMultiplierAndShiftTosaScale32.  scale32 selects a
        31-bit multiplier, otherwise a 15-bit one is used.  Returns
        (multiplier, shift) with shift kept within [2, 62].
        """
        scaleBits = 31 if scale32 else 15

        # frexp gives scaleFp == m * 2**exp with 0.5 <= |m| < 1 (for non-zero input)
        m, shift = math.frexp(scaleFp)
        if scaleFp < 0.0:
            # Sign is discarded here, matching the reference implementation
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            # Rounding overflowed into an extra bit - renormalize
            multiplier //= 2
            shift += 1

        shift = scaleBits - shift
        logger.debug(
            f"computeMultiplierAndShift: scalefp={scaleFp} scaleBits={scaleBits} m={m} mult={multiplier} shift={shift}"
        )

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier //= 4
            shift += 2
        elif shift == 1:
            multiplier //= 2
            shift += 1
        elif shift == 63:
            multiplier *= 2
            shift -= 1

        assert multiplier <= (1 << scaleBits)
        assert 2 <= shift <= 62

        return multiplier, shift
156
157
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        # Stateless - all generators are static methods.
        pass

    @staticmethod
    def tgBasic(testGen, rng, op, rank, error_name=None):
        """Return (pl + const) copies of one random shape of the given rank.

        For RankMismatch error tests the shape used for subsequent operands is
        regenerated with a different rank.
        """
        pl, const = op["operands"]
        shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    # Rank 1 can only grow (rank-1 would be rank 0)
                    shape = testGen.makeShape(rng, rank + rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rng, rank + rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, rng, op, rank, error_name=None):
        """Return (pl + const) copies of one random rank 4 (NHWC) shape."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rng, rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgGather(testGen, rng, opName, rank, error_name=None):
        """Return [values, indices] shapes for GATHER.

        values is a random rank 3 shape [N, K, C]; indices is [N, W].
        """
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_shape = testGen.makeShape(rng, rank)
        values_shape = testGen.constrictBatchSize(values_shape)

        N = values_shape[0]
        W = testGen.makeDimension(rng)
        indices_shape = [N, W]

        shape_list = [values_shape, indices_shape]
        return shape_list

    @staticmethod
    def tgScatter(testGen, rng, opName, rank, error_name=None):
        """Return [values_in, indices, input] shapes for SCATTER.

        values_in is [N, K, C]; indices is [N, W]; input is [N, W, C] with W
        capped at K so each output index is written at most once.
        """
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rng, rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = rng.randInt(W_min, W_max) if W_min < W_max else W_min

        input_shape = [N, W, C]

        shape_list = []
        shape_list.append(values_in_shape)
        shape_list.append([N, W])  # indices
        shape_list.append(input_shape)

        return shape_list

    @staticmethod
    def _get_broadcast_shapes(testGen, rng, num_shapes, rank, error_name=None):
        """Return num_shapes shapes where one randomly chosen shape has a
        randomly chosen dimension set to 1 so it broadcasts against the rest
        (or is deliberately mismatched for the relevant error tests)."""
        shape = testGen.makeShape(rng, rank)
        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = rng.randInt(0 if error_name is None else 1, num_shapes)
        fuzz_idx = rng.randInt(0, rank)

        for i in range(num_shapes):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(rng, extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, rng, op, rank, error_name=None):
        """Return broadcastable shapes for all (pl + const) operands."""
        pl, const = op["operands"]
        num_shapes = pl + const
        return TosaTensorGen._get_broadcast_shapes(
            testGen, rng, num_shapes, rank, error_name
        )

    @staticmethod
    def tgMul(testGen, rng, op, rank, error_name=None):
        """Return two broadcastable input shapes plus the [1] shift tensor for MUL."""
        # Get broadcast shapes for the first 2 inputs as the 3rd is shift
        shape_list = TosaTensorGen._get_broadcast_shapes(
            testGen, rng, 2, rank, error_name
        )
        # Add a single dimension tensor for shift
        shape_list.append([1])
        return shape_list

    @staticmethod
    def tgConv2D(testGen, rng, op, rank, error_name=None):
        """Return [ifm, filter, bias] shapes for CONV2D-style operators."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension(rng)

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC or 1 if broadcastable
        try:
            if op["broadcastable_bias"]:
                if rng.choice([True, False]):
                    ofm_depth = 1
        except KeyError:
            # Operator does not declare broadcastable_bias - bias stays OC
            pass
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, rng, op, rank, error_name=None):
        """Return [ifm, filter, bias] shapes for CONV3D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension(rng)

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, rng, op, rank, error_name=None):
        """Return [ifm, filter, bias] shapes for TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension(rng)

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, rng, op, rank, error_name=None):
        """Return [ifm, filter, bias] shapes for DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension(rng) % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, rng, op, rank, error_name=None):
        """Return two matching rank 3 (NHW) shapes with power-of-two H and W."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rng, rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if the current size is 1 (1 + 1 is still a power of two)
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, rng, op, rank, error_name=None):
        """Return one rank 3 (NHW) shape with power-of-two H and W for RFFT2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rng, rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, rng, op, rank, error_name=None):
        """Return [input, filter, bias] shapes for FULLY_CONNECTED."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, rng, op, rank, error_name=None):
        """Return [A, B] rank 3 shapes for MATMUL with matching batch/inner dims."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, rng, op, rank, error_name=None):
        """Return shapes for CONCAT, adding up to 3 extra tensors to concatenate."""
        pl, const = op["operands"]
        shape = testGen.makeShape(rng, rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = rng.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(rng, shapeList, axis, error_name=None):
        """Split the first CONCAT shape along axis so extra const inputs stay small.

        Returns the (possibly invalidated) shape list unchanged when the axis
        cannot be split.
        """
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
649
650
651class TosaTensorValuesGen:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100652 """Tensor Value generators create the random data for each tensor in each test."""
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100653
    def __init__(self):
        # Stateless - all value generators are static methods.
        pass
656
Jeremy Johnson1271c442023-09-05 11:39:26 +0100657 class TVGInfo:
658 """Enhanced tensor values information including data gen dict."""
659
660 def __init__(self, tensorList, dataGenDict):
661 self.tensorList = tensorList
662 self.dataGenDict = dataGenDict
663
    # Default high value for random numbers
    # Each entry is the largest finite value of the format:
    # FP32: (2 - 2**-23) * 2**127, FP16: 65504, BF16: (2 - 2**-7) * 2**127,
    # FP8E4M3: 448, FP8E5M2: 57344
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
        DType.FP8E4M3: 448,
        DType.FP8E5M2: 57344,
    }

    # Default lowest normal values for random numbers
    # NOTE(review): the FP8 entries (2**-9, 2**-16) are the smallest
    # *subnormal* values of those formats rather than the smallest normals
    # (2**-6, 2**-14) - confirm this is intentional.
    TVG_FLOAT_LOW_VALUE = {
        DType.FP32: np.exp2(-126),
        DType.FP16: np.exp2(-14),
        DType.BF16: np.exp2(-126),
        DType.FP8E4M3: np.exp2(-9),
        DType.FP8E5M2: np.exp2(-16),
    }
681
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100682 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100683 def _get_data_range(rng, dtype, highValueLookup, lowValueLookup=None):
Jeremy Johnson30476252023-11-20 16:15:30 +0000684 # Return a tuple of (low,high) data range values for the given data
685 # type using a combination of per operator table limits, data limits
686 # and user supplied ranges for FP numbers
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000687 if dtype in highValueLookup:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100688 type_range = rng.dTypeRange(dtype, high_inclusive=True)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000689 high_val = highValueLookup[dtype]
Jeremy Johnson30476252023-11-20 16:15:30 +0000690 if lowValueLookup is not None and dtype in lowValueLookup:
691 low_val = lowValueLookup[dtype]
692 else:
693 low_val = -high_val
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000694 # Set the values to something that won't produce infinity whilst
Jeremy Johnson30476252023-11-20 16:15:30 +0000695 # respecting the default ranges if more/less than the low/high
696 # values
697 data_range = (
698 max(low_val, type_range[0]),
699 min(high_val, type_range[1]),
700 )
701 if data_range[0] > data_range[1]:
702 # Invalid data range from low to high created due to user
703 # constraints revert to using internal ranges as they are
704 # known to work
Jeremy Johnsonaf090182024-02-13 18:25:39 +0000705 logger.info(
706 f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
707 )
Jeremy Johnson30476252023-11-20 16:15:30 +0000708 data_range = (low_val, high_val)
709 return data_range
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000710 return None
711
    @staticmethod
    def tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Create input tensors and any data generator meta-data for an op.

        Two paths:
        * Internal generation - data arrays are created here (ERROR_IF tests,
          types/ops not supported by compliance data generation).
        * Data generator meta-data - a JSON-able description is built so the
          data generator library can create the data, lazily if requested.

        Returns a TVGInfo of the serializer tensors plus the meta-data
        (None when internal generation was used).
        """
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
        if "p_count" in argsDict:
            # Override for operators like CONCAT
            pCount = argsDict["p_count"]
            cCount = argsDict["c_count"]
        assert pCount + cCount == len(
            shapeList
        ), "Placeholders & Constant tensors must match shapes list"

        tens_ser_list = []

        if (
            error_name is not None
            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
            or "data_gen" not in testGen.TOSA_OP_LIST[opName]
        ):
            # Fall back to internal data gen when dealing with unsupported types or ops
            data_range = argsDict["data_range"] if "data_range" in argsDict else None
            for idx, info in enumerate(zip(shapeList, dtypeList)):
                roundMode = False
                shape, dtype = info
                if "data_range_list" in argsDict:
                    # Per-tensor ranges take precedence over the op-wide range
                    data_range = argsDict["data_range_list"][idx]["range"]
                    roundMode = (
                        "round" in argsDict["data_range_list"][idx]
                        and argsDict["data_range_list"][idx]["round"] is True
                    )
                if data_range is not None and dtype not in (
                    DType.FP16,
                    DType.FP32,
                    DType.BF16,
                    DType.FP8E4M3,
                    DType.FP8E5M2,
                ):
                    # Change from inclusive to exclusive range
                    data_range = (data_range[0], data_range[1] + 1)

                # Ignore lazy data gen option and create data array using any range limits
                if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                    # Pre-determined values (e.g. shape/multiplier operands)
                    if dtype == DType.SHAPE:
                        arr = np.int64(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT8:
                        arr = np.int8(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT16:
                        arr = np.int16(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT32:
                        arr = np.int32(argsDict["fixed_data"][idx])
                    else:
                        assert False, "Unsupported fixed_data type"
                else:
                    arr = rng.randTensor(shape, dtype, data_range)
                if roundMode:
                    arr = np.round(arr)
                # First pCount tensors are variable inputs, the rest constants
                if idx < pCount:
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
        dg_type = argsDict["dg_type"]
        tens_data = {
            "version": "0.1",
            "tensors": {},
        }
        dg_tens_meta = tens_data["tensors"]
        for idx, shape in enumerate(shapeList):

            tens_meta = {}
            # Fixed data for this input overrides the op's generator type
            if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                tens_meta["generator"] = gtu.DataGenType(
                    gtu.DataGenType.FIXED_DATA
                ).name
            else:
                tens_meta["generator"] = gtu.DataGenType(dg_type).name

            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()

            if idx < pCount:
                tens_meta["input_type"] = "VARIABLE"
            else:
                tens_meta["input_type"] = "CONSTANT"

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                if (
                    tens_meta["generator"]
                    == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name
                ):
                    info["data"] = [int(i) for i in argsDict["fixed_data"][idx]]
                    tens_meta["fixed_data_info"] = info
                else:
                    info["rng_seed"] = rng.seed

                    # Work out the data range - per-tensor list takes precedence
                    data_range = None
                    if "data_range_list" in argsDict:
                        data_range = argsDict["data_range_list"][idx]["range"]
                        if "round" in argsDict["data_range_list"][idx]:
                            info["round"] = argsDict["data_range_list"][idx]["round"]
                    elif "data_range" in argsDict:
                        data_range = argsDict["data_range"]

                    if data_range is None:
                        data_range = rng.dTypeRange(dtypeList[idx], high_inclusive=True)
                    info["range"] = [str(v) for v in data_range]
                    tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = int(argsDict["ks"])
                if "acc_type" in argsDict:
                    # Convert type number into JSON name
                    info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
                        "json"
                    ]
                if "kernel" in argsDict:
                    info["kernel"] = [int(k) for k in argsDict["kernel"]]
                if "axis" in argsDict:
                    info["axis"] = int(argsDict["axis"])
                tens_meta["dot_product_info"] = info
            elif dg_type == gtu.DataGenType.FULL_RANGE:
                info = {}
                info["start_val"] = int(
                    rng.randInt(0, gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["fullset"])
                )
                tens_meta["full_range_info"] = info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"

            # Using the finished generate config meta data - generate the data if
            # needed and assign a tensor name from the serializer

            # Need to generate data when not lazy or for the bias tensor as we need
            # to work out if the bias data is non-zero for compliance
            if not testGen.args.lazy_data_gen or (
                idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
            ):
                # Give this tensor a temporary name until we get one from the serializer
                temp_name = f"placeholder_{idx}"
                dg_tens_meta[temp_name] = tens_meta
                # Create data now using the temporary name to access meta details
                data = testGen.dgl.get_tensor_data(temp_name, tens_data)
                if tens_meta["data_type"] == "SHAPE":
                    # Tensor type SHAPE and Numpy file type must be the same
                    data = np.int64(data)
                # Remove the item as we will give it the correct name later
                del dg_tens_meta[temp_name]

            if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
                # The KS value used by compliance verification is altered when the
                # bias data is non-zero
                # NOTE(review): assumes the bias (input 2) is 1-D so max() works
                if max(abs(data)) > 0.0:
                    argsDict["ksb"] = argsDict["ks"] + 1

            if testGen.args.lazy_data_gen:
                # Data will be produced later by the data generator library
                data = None

            if tens_meta["input_type"] == "VARIABLE":
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
            else:
                tens = testGen.ser.addConst(shape, dtypeList[idx], data)

            tens_ser_list.append(tens)
            # Add the meta data to the list using the serializer tensor name
            dg_tens_meta[tens.name] = tens_meta

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
889
890 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100891 def tvgNegate(
892 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
893 ):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100894 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000895 # Integer test
896 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100897 pCount, cCount = op["operands"]
898 assert (
899 pCount == 1 and cCount == 0
900 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100901 # Must create tensors with values within accumulator (int32) negatable
902 # range
903 max_val = (1 << 31) - 1
904 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100905 arr = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100906 rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100907 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000908 tens_ser_list = []
909 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100910 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
911 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000912 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100913 else:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000914 # ERROR_IF or floating point test
915 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100916 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100917 )
918
    # Set the ADD/SUB data range to half the largest value to avoid infinities
    # Passed as the high value lookup to _get_data_range for the FP data types
    # supported by compliance testing
    TVG_FLOAT_HIGH_VALUE_ADDSUB = {
        DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
        DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
        DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
    }
925
    @staticmethod
    def tvgAddSub(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Create input values for ADD/SUB (and the _SHAPE variants).

        Integer tests clip the second operand so the int64 result of the
        add/subtract cannot saturate int32; FP and ERROR_IF tests use the
        default lazy generation with a restricted data range.
        """
        if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
            # Make sure the integer operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            tens_ser_list = []
            add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
            data_range = None  # Use default
            if op["op"] in (Op.ADD_SHAPE, Op.SUB_SHAPE):
                # Shape operands use the user controllable shape value range
                data_range = testGen.args.tensor_shape_range
            a_arr = rng.randTensor(shapeList[0], dtypeList[0], data_range)
            b_arr = rng.randTensor(shapeList[1], dtypeList[1], data_range)
            # Compute the result in int64 so any int32 overflow is observable
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            data_range = TosaTensorValuesGen._get_data_range(
                rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
            )
            if data_range:
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )
1007
1008 @staticmethod
1009 def tvgCondIfWhileLoop(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001010 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001011 ):
1012 if dtypeList[0] in (
1013 DType.INT32,
1014 DType.INT16,
1015 DType.INT8,
1016 ):
1017 # Limit input tensors with cond_if_binary or while_loop to stop
1018 # saturation of add/sub ops with int32 and keep all logical shift
1019 # values between 0 to 31 for int16 or int8
Jeremy Johnson587cc842024-02-08 11:45:44 +00001020 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001021 pCount, cCount = op["operands"]
1022 pRemain = pCount
Jeremy Johnson587cc842024-02-08 11:45:44 +00001023 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001024 for idx, shape in enumerate(shapeList[:]):
1025 if dtypeList[0] == DType.INT32:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001026 arr = rng.randTensor(shapeList[idx], DType.INT16)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001027 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001028 arr = np.int32(rng.integers(low=0, high=32, size=shapeList[idx]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001029 if pRemain > 0:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001030 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001031 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
1032 )
1033 pRemain -= 1
1034 else:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001035 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001036 testGen.ser.addConst(shape, dtypeList[idx], arr)
1037 )
1038
Jeremy Johnson587cc842024-02-08 11:45:44 +00001039 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001040 else:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001041 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001042 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001043 )
1044
1045 @staticmethod
1046 def tvgArithmeticRightShift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001047 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001048 ):
Jeremy Johnson587cc842024-02-08 11:45:44 +00001049 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001050 pCount, cCount = op["operands"]
1051 # Force value of operand[1] to be within [0, num_bits]
1052 assert (
1053 pCount == 2 and cCount == 0
1054 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
1055
Jeremy Johnson587cc842024-02-08 11:45:44 +00001056 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001057 for idx, shape in enumerate(shapeList[:]):
1058 if idx == 1:
1059 if dtypeList[idx] == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001060 arr = np.int32(rng.integers(low=0, high=8, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001061 elif dtypeList[idx] == DType.INT16:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001062 arr = np.int32(rng.integers(low=0, high=16, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001063 elif dtypeList[idx] == DType.INT32:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001064 arr = np.int32(rng.integers(low=0, high=32, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001065 elif error_name == ErrorIf.WrongInputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001066 arr = np.int32(rng.integers(low=0, high=8, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001067 else:
1068 raise Exception("OpArithmeticRightShift: invalid input dtype")
1069 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001070 arr = rng.randTensor(shape, dtypeList[idx])
Jeremy Johnson587cc842024-02-08 11:45:44 +00001071 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001072
Jeremy Johnson587cc842024-02-08 11:45:44 +00001073 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001074
1075 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001076 def tvgReshape(
1077 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1078 ):
Won Jeon64e4bfe2024-01-18 06:31:55 +00001079 dtypeList[1] = DType.SHAPE
1080 shapeList[1] = [len(argsDict["new_shape"])]
1081 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1082 argsDict["fixed_data"] = [None, argsDict["new_shape"]]
1083
1084 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001085 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Won Jeon64e4bfe2024-01-18 06:31:55 +00001086 )
1087
1088 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001089 def tvgRescale(
1090 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1091 ):
Tai Ly6e1e2bc2024-03-01 20:59:32 +00001092 scale32 = argsDict["scale"]
1093 multiplier_arr = argsDict["multiplier"]
1094 shift_arr = argsDict["shift"]
1095
1096 if scale32:
1097 dtypeList[1] = DType.INT32
1098 else:
1099 dtypeList[1] = DType.INT16
1100 shapeList[1] = [len(multiplier_arr)]
1101 dtypeList[2] = DType.INT8
1102 shapeList[2] = [len(shift_arr)]
1103 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1104 argsDict["fixed_data"] = [None, multiplier_arr, shift_arr]
1105
1106 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001107 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Tai Ly6e1e2bc2024-03-01 20:59:32 +00001108 )
1109
1110 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001111 def tvgPad(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Tai Lye095da72024-01-25 22:00:18 +00001112 # argsDict["pad"] is 2D array, need to flatten it to get list of values
1113 pad_values = argsDict["pad"].flatten()
1114 dtypeList[1] = DType.SHAPE
1115 shapeList[1] = [len(pad_values)]
1116 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1117 argsDict["fixed_data"] = [None, pad_values]
1118
1119 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001120 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Tai Lye095da72024-01-25 22:00:18 +00001121 )
1122
1123 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001124 def tvgSlice(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
TatWai Chongf15bad82024-01-31 21:33:27 -08001125 dtypeList[1] = DType.SHAPE
1126 shapeList[1] = [len(argsDict["start"])]
1127 dtypeList[2] = DType.SHAPE
1128 shapeList[2] = [len(argsDict["size"])]
1129 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1130 argsDict["fixed_data"] = [None, argsDict["start"], argsDict["size"]]
1131
1132 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001133 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
TatWai Chongf15bad82024-01-31 21:33:27 -08001134 )
1135
1136 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001137 def tvgTile(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Won Jeon64e4bfe2024-01-18 06:31:55 +00001138 dtypeList[1] = DType.SHAPE
1139 shapeList[1] = [len(argsDict["multiples"])]
1140 argsDict["fixed_data"] = [None, argsDict["multiples"]]
1141
1142 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001143 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Won Jeon64e4bfe2024-01-18 06:31:55 +00001144 )
1145
1146 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001147 def tvgSelect(
1148 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1149 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001150 # Set datatype of condition tensor to boolean
1151 dtypeList[0] = DType.BOOL
1152
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00001153 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001154 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001155 )
1156
1157 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001158 def tvgIntDiv(
1159 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1160 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001161 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001162 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001163 pCount, cCount = op["operands"]
1164 assert (
1165 pCount == 2 and cCount == 0
1166 ), "Op.INTDIV must have 2 placeholders, 0 consts"
1167
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001168 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001169
1170 # Two invalid cases for Op.INTDIV:
1171 # 1. divisor == 0
1172 # 2. dividend == -(1<<31) and divisor == -1
1173 while True:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001174 dividend_arr = rng.randTensor(shapeList[0], dtypeList[0])
1175 divisor_arr = rng.randTensor(shapeList[1], dtypeList[1])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001176
1177 if (divisor_arr == 0).any():
1178 continue
1179
1180 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
1181 continue
1182
1183 break
1184
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001185 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001186 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1187 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001188 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001189 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1190 )
1191
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001192 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001193 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001194 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001195 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001196 )
1197
    # Set the MUL data range to the square root of the largest value
    # to avoid infinities
    # Passed as the high value lookup to _get_data_range for the FP data types
    # supported by compliance testing
    TVG_FLOAT_HIGH_VALUE_MUL = {
        DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
1205
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001206 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001207 def tvgMul(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001208 if error_name is not None or dtypeList[0] in (
1209 DType.FP16,
1210 DType.BF16,
1211 DType.FP32,
1212 ):
1213 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001214 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001215 rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001216 )
1217 if data_range:
1218 argsDict["data_range"] = data_range
1219
Jeremy Johnson0a042992024-02-28 13:20:05 +00001220 if dtypeList[0] != DType.SHAPE:
1221 # Need to supply shift tensor for MUL (not needed for MUL_SHAPE)
1222 dtypeList[2] = DType.INT8
1223 shapeList[2] = [1]
1224 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1225 argsDict["fixed_data"] = [None, None, [argsDict["shift"]]]
1226
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001227 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001228 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001229 )
1230 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001231 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001232 pCount, cCount = op["operands"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001233
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001234 tens_ser_list = []
1235
1236 # Make sure multiply result in int32 range
Won Jeon74342e52024-01-09 00:34:40 +00001237 if dtypeList[0] == DType.SHAPE:
1238 shift = 0
1239 else:
1240 shift = argsDict["shift"]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001241 if dtypeList[0] == DType.INT8:
1242 num_bits = 8
1243 elif dtypeList[0] == DType.INT16:
1244 num_bits = 16
Won Jeon74342e52024-01-09 00:34:40 +00001245 elif dtypeList[0] in (DType.INT32, DType.SHAPE):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001246 num_bits = 32
1247 elif error_name == ErrorIf.WrongInputType:
1248 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001249 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001250 raise Exception(
1251 f"OpMul: invalid input dtype {gtu.DTYPE_ATTRIBUTES[dtypeList[0]]['str']}"
1252 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001253
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001254 for idx, shape in enumerate(shapeList[:]):
Won Jeon74342e52024-01-09 00:34:40 +00001255 if dtypeList[idx] == DType.SHAPE:
1256 low = testGen.args.tensor_shape_range[0]
1257 high = testGen.args.tensor_shape_range[1]
1258 else:
1259 low = -(2 ** (num_bits - 1))
1260 high = (2 ** (num_bits - 1)) - 1
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001261
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001262 a_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[0]))
1263 b_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[1]))
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001264
1265 i = 0
1266 while True:
1267
1268 a_arr_64 = a_arr.astype(np.int64)
1269 b_arr_64 = b_arr.astype(np.int64)
1270
1271 if shift > 0:
1272 rounding = 1 << (shift - 1)
1273 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001274 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001275 result_arr = a_arr_64 * b_arr_64
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001276
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001277 if (result_arr > -(2**31)).all() and (
1278 result_arr <= ((2**31) - 1)
1279 ).all():
1280 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001281
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001282 i = i + 1
1283 a_arr = a_arr // 2
1284 b_arr = b_arr // 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001285
Won Jeon74342e52024-01-09 00:34:40 +00001286 if dtypeList[0] == DType.SHAPE:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001287 # MUL_SHAPE with 2 inputs
Won Jeon74342e52024-01-09 00:34:40 +00001288 tens_ser_list.append(
1289 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
1290 )
1291 tens_ser_list.append(
1292 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
1293 )
1294 else:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001295 # MUL with 3 inputs (3rd is shift)
Won Jeon74342e52024-01-09 00:34:40 +00001296 tens_ser_list.append(
1297 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1298 )
1299 tens_ser_list.append(
1300 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1301 )
Jeremy Johnson0a042992024-02-28 13:20:05 +00001302 tens_ser_list.append(
1303 testGen.ser.addPlaceholder([1], DType.INT8, np.int8([shift]))
1304 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001305
1306 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001307
1308 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001309 def tvgConcat(
1310 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1311 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001312 count = len(shapeList) - testGen.args.num_const_inputs_concat
1313 if count < 1:
1314 count = 1
1315 if testGen.args.num_const_inputs_concat == 0:
1316 count = len(shapeList)
1317
Won Jeon74342e52024-01-09 00:34:40 +00001318 op = testGen.TOSA_OP_LIST[opName]
1319 if op["op"] == Op.CONCAT_SHAPE:
1320 # Set the axis to 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001321 shapeList = TosaTensorGen.tgConcatConstInput(rng, shapeList, 0, error_name)
Won Jeon74342e52024-01-09 00:34:40 +00001322 else:
1323 shapeList = TosaTensorGen.tgConcatConstInput(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001324 rng, shapeList, argsDict["axis"], error_name
Won Jeon74342e52024-01-09 00:34:40 +00001325 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001326
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001327 # Override default pCount/cCount for operator
1328 argsDict["p_count"] = count
1329 argsDict["c_count"] = len(shapeList) - count
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001330
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001331 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001332 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001333 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001334
1335 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001336 def tvgLogicalShift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001337 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001338 ):
1339 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001340 pCount, cCount = op["operands"]
1341 assert (
1342 pCount == 2 and cCount == 0
1343 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001344 values_arr = rng.randTensor(shapeList[0], dtypeList[0])
1345 shift_arr = np.int32(rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001346 tens_ser_list = []
1347 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001348 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1349 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001350 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001351 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1352 )
1353
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001354 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001355
1356 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001357 def tvgEqual(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnsona0150012023-11-15 15:52:06 +00001358 if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
1359 # Integer
1360 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001361 pCount, cCount = op["operands"]
1362 assert (
1363 pCount == 2 and cCount == 0
1364 ), "Op.EQUAL must have 2 placeholders, 0 consts"
Jeremy Johnsona0150012023-11-15 15:52:06 +00001365
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001366 a_arr = rng.randTensor(shapeList[0], dtypeList[0])
1367 b_arr = rng.randTensor(shapeList[1], dtypeList[1])
Jeremy Johnsona0150012023-11-15 15:52:06 +00001368
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001369 # Using random numbers means that it will be very unlikely that
1370 # there are any matching (equal) values, therefore force that
1371 # there are twice the number of matching values as the tensor rank
1372 for num in range(0, len(shapeList[0]) * 2):
1373 a_index = []
1374 b_index = []
1375 # Choose an index in each axis for the whole shape
1376 for axis in range(0, len(shapeList[0])):
1377 # Index can be up to the largest dimension in both shapes
1378 index = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001379 rng.integers(0, max(shapeList[0][axis], shapeList[1][axis]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001380 )
1381 # Reduce the index down to a shape's dim for broadcasting
1382 a_index.append(min(shapeList[0][axis] - 1, index))
1383 b_index.append(min(shapeList[1][axis] - 1, index))
1384
1385 a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
1386
Jeremy Johnsona0150012023-11-15 15:52:06 +00001387 tens_ser_list = []
1388 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001389 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1390 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001391 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001392 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1393 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001394 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001395 else:
Jeremy Johnsona0150012023-11-15 15:52:06 +00001396 # ERROR_IF or floating point test
1397 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001398 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001399 )
1400
1401 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001402 def tvgReduceSum(
1403 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1404 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001405 dtype = dtypeList[0]
1406 if dtype == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001407 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001408 pCount, cCount = op["operands"]
1409 assert (
1410 pCount == 1 and cCount == 0
1411 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
1412 # Limit values so that the sum cannot exceed the range of an int32 during
1413 # summation of any axis
1414 range_val = int((1 << 31) / max(shapeList[0]))
1415 values_arr = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001416 rng.integers(low=-range_val, high=range_val, size=shapeList[0])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001417 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001418 tens_ser_list = []
1419 tens_ser_list.append(
Jeremy Johnson30476252023-11-20 16:15:30 +00001420 testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001421 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001422 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001423 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001424 # ERROR_IF or dot product floating point test
Jeremy Johnson30476252023-11-20 16:15:30 +00001425 if (
1426 error_name is None
1427 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
1428 ):
1429 # Limit ranges for (non error & non compliance) tests by using
1430 # values that can be summed on any axis to not hit infinity
1431 highval_lookup = {
1432 dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
1433 / max(shapeList[0])
1434 }
1435 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001436 rng, dtype, highval_lookup
Jeremy Johnson30476252023-11-20 16:15:30 +00001437 )
1438 assert data_range is not None
1439 argsDict["data_range"] = data_range
1440
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001441 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001442 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001443 )
1444
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001445 @staticmethod
1446 def tvgReduceProduct(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001447 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001448 ):
1449 dtype = dtypeList[0]
1450 if error_name is None:
1451 # Limit ranges for (non error) tests by using
1452 # values that can be multiplied on any axis to not hit infinity
1453 highval_lookup = {
1454 dtype: math.pow(
1455 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
1456 1 / max(shapeList[0]),
1457 )
1458 }
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001459 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001460 assert data_range is not None
1461 argsDict["data_range"] = data_range
1462
1463 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001464 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001465 )
1466
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001467 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001468 def tvgResize(
1469 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1470 ):
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001471 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001472 rng,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001473 dtypeList[0],
1474 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1475 )
1476 if data_range:
1477 argsDict["data_range"] = data_range
1478 # Needed for compliance
1479 argsDict["max_abs_value"] = data_range[1]
1480
Tai Lyc5c2a7e2024-02-22 23:26:28 +00001481 scale_values = argsDict["scale"]
1482 offset_values = argsDict["offset"]
1483 border_values = argsDict["border"]
1484 dtypeList[1] = DType.SHAPE
1485 dtypeList[2] = DType.SHAPE
1486 dtypeList[3] = DType.SHAPE
1487 shapeList[1] = [len(scale_values)]
1488 shapeList[2] = [len(offset_values)]
1489 shapeList[3] = [len(border_values)]
1490 argsDict["fixed_data"] = [None, scale_values, offset_values, border_values]
1491
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001492 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001493 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001494 )
1495
    # Set the POW exponent high data range
    TVG_FLOAT_HIGH_VALUE_POW_EXP = {
        DType.FP32: 10.0,
        DType.FP16: 10.0,
        DType.BF16: 10.0,
    }
    # POW highest base value (within a safe margin of error) that can be raised
    # to +ve exponent that doesn't become Infinity
    # i.e. the (1/max_exponent)'th root of the type's highest value, rounded down
    TVG_FLOAT_HIGH_VALUE_POW_BASE = {
        DType.FP32: math.floor(
            math.pow(
                TVG_FLOAT_HIGH_VALUE[DType.FP32],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
            )
        ),
        DType.FP16: math.floor(
            math.pow(
                TVG_FLOAT_HIGH_VALUE[DType.FP16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
            )
        ),
        DType.BF16: math.floor(
            math.pow(
                TVG_FLOAT_HIGH_VALUE[DType.BF16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
            )
        ),
    }
    # POW lowest base value (within a safe margin of error) that can be raised
    # to -ve exponent that doesn't become Infinity
    # i.e. the reciprocal case of the above, rounded up to 3 decimal places
    TVG_FLOAT_LOW_VALUE_POW_BASE = {
        DType.FP32: math.ceil(
            math.pow(
                1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
            )
            * 1000
        )
        / 1000,
        DType.FP16: math.ceil(
            math.pow(
                1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
            )
            * 1000
        )
        / 1000,
        DType.BF16: math.ceil(
            math.pow(
                1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
            )
            * 1000
        )
        / 1000,
    }
1552
    @staticmethod
    def tvgPow(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create POW inputs using per-test-set base/exponent ranges.

        Test set 0: positive base, fractional exponent
        Test set 1: positive base, integer exponent
        Test set 2: negative base, integer exponent
        The chosen ranges keep pow(base, exponent) away from Infinity/NaN.
        """
        if error_name is not None:
            # ERROR_IF tests do not need restricted data ranges
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )
        dtype = dtypeList[0]
        # Different ranges for POW
        test_set = argsDict["s"]
        if test_set == 0:
            # Positive base with fractional exponent
            base_range = TosaTensorValuesGen._get_data_range(
                rng,
                dtype,
                TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
                TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
            )
            exp_range = TosaTensorValuesGen._get_data_range(
                rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
            )
            exp_round = False
        else:
            # Integer exponent
            exp_range = TosaTensorValuesGen._get_data_range(
                rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
            )
            exp_round = True
            if test_set == 1:
                # Positive base
                base_range = TosaTensorValuesGen._get_data_range(
                    rng,
                    dtype,
                    TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
                    TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
                )
            else:
                assert test_set == 2
                # Negative base
                # Supply new look up tables with negative values
                base_range = TosaTensorValuesGen._get_data_range(
                    rng,
                    dtype,
                    {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
                    {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
                )

        # First entry is the base tensor range, second the exponent tensor range
        data_range_list = (
            {
                "range": base_range,
            },
            {
                "range": exp_range,
                "round": exp_round,
            },
        )
        argsDict["data_range_list"] = data_range_list
        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )
1612
1613 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001614 def tvgLogRsqrt(
1615 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1616 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001617 # LOG & RSQRT data range from lowest expressible positive number to
1618 # largest to avoid NaNs
1619 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001620 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001621 dtypeList[0],
1622 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1623 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1624 )
1625 if data_range:
1626 argsDict["data_range"] = data_range
1627
1628 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001629 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001630 )
1631
    # Set the EXP data range to the log of the largest to smallest values
    # to avoid infinities or making the result zero
    TVG_FLOAT_HIGH_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
    # Lower bound: exp() of anything below this underflows towards zero
    TVG_FLOAT_LOW_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
    }
1644
1645 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001646 def tvgExp(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001647 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001648 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001649 dtypeList[0],
1650 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1651 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1652 )
1653 if data_range:
1654 argsDict["data_range"] = data_range
1655
1656 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001657 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001658 )
1659
1660 @staticmethod
1661 def tvgFullyConnected(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001662 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson30476252023-11-20 16:15:30 +00001663 ):
1664 dtype = dtypeList[0]
1665 if (
1666 error_name is None
1667 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
Jeremy Johnson718f3472023-11-30 14:18:19 +00001668 and dtype in (DType.BF16,)
Jeremy Johnson30476252023-11-20 16:15:30 +00001669 ):
Jeremy Johnson718f3472023-11-30 14:18:19 +00001670 # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
Jeremy Johnson30476252023-11-20 16:15:30 +00001671 # Limit ranges for (non error & non compliance) FP tests by using
1672 # values that can be multiplied on any axis to not hit infinity/NaN
1673 IC = shapeList[0][1]
1674 highval_lookup = {
1675 dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
1676 }
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001677 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnson30476252023-11-20 16:15:30 +00001678 assert data_range is not None
1679 argsDict["data_range"] = data_range
1680
1681 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001682 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001683 )
1684
Jeremy Johnson708da822023-11-15 16:25:45 +00001685 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001686 def tvgCast(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson708da822023-11-15 16:25:45 +00001687 in_dtype = dtypeList[0]
1688 out_dtype = argsDict["out_type"]
1689 # Create look up to limit input tensor to output type maximums to avoid
1690 # FP infinities and saturation of integers
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001691 out_range = rng.dTypeRange(out_dtype, high_inclusive=True)
Jeremy Johnson708da822023-11-15 16:25:45 +00001692 highval_lookup = {in_dtype: out_range[1]}
1693 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001694 rng,
Jeremy Johnson708da822023-11-15 16:25:45 +00001695 in_dtype,
1696 highval_lookup,
1697 )
1698
1699 assert data_range is not None
1700 argsDict["data_range"] = data_range
1701
1702 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001703 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson708da822023-11-15 16:25:45 +00001704 )
1705
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001706 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001707 def tvgGather(
1708 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1709 ):
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001710 K = shapeList[0][1]
1711
1712 # Fix the type of the indices tensor
1713 dtypeList[1] = DType.INT32
1714
1715 dtype = dtypeList[0]
1716 if not gtu.dtypeIsSupportedByCompliance(dtype):
1717 # Test unsupported by data generator
1718 op = testGen.TOSA_OP_LIST[opName]
1719 pCount, cCount = op["operands"]
1720 assert (
1721 pCount == 2 and cCount == 0
1722 ), "Op.GATHER must have 2 placeholders, 0 consts"
1723
1724 tens_ser_list = []
1725 for idx, shape in enumerate(shapeList):
1726 dtype = dtypeList[idx]
1727 if idx != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001728 arr = rng.randTensor(shape, dtype)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001729 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1730 else:
1731 # Limit data range of indices tensor upto K (exclusive)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001732 arr = rng.randTensor(shape, dtype, (0, K))
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001733 # To match old functionality - create indices as CONST
1734 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1735
1736 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1737
1738 else:
1739 # ERROR_IF or floating point test
1740 # Use inclusive values upto index K for indices tensor
1741 data_range_list = (
1742 {"range": None},
1743 {"range": (0, K - 1)},
1744 )
1745 argsDict["data_range_list"] = data_range_list
1746
1747 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001748 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001749 )
1750
1751 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001752 def tvgScatter(
1753 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1754 ):
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001755 K = shapeList[0][1]
1756 W = shapeList[2][1]
1757
1758 # Work out an indices tensor here with data that doesn't exceed the
1759 # dimension K of the values_in tensor and does NOT repeat the same K
1760 # location as needed by the spec:
1761 # "It is not permitted to repeat the same output index within a single
1762 # SCATTER operation and so each output index occurs at most once."
1763 assert K >= W, "Op.SCATTER W must be smaller or equal to K"
1764
1765 # Fix the type of the indices tensor
1766 dtypeList[1] = DType.INT32
1767
1768 dtype = dtypeList[0]
1769 if not gtu.dtypeIsSupportedByCompliance(dtype):
1770 # Test unsupported by data generator
1771 op = testGen.TOSA_OP_LIST[opName]
1772 pCount, cCount = op["operands"]
1773 assert (
1774 pCount == 3 and cCount == 0
1775 ), "Op.SCATTER must have 3 placeholders, 0 consts"
1776
1777 tens_ser_list = []
1778 for idx, shape in enumerate(shapeList):
1779 dtype = dtypeList[idx]
1780 if idx != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001781 arr = rng.randTensor(shape, dtype)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001782 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1783 else:
1784 # Create the indices array
1785 assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
1786 arr = []
1787 for n in range(shape[0]):
1788 # Get a shuffled list of output indices (0 to K-1) and
1789 # limit length to W
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001790 arr.append(rng.permutation(K)[:W])
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001791 indices_arr = np.array(arr, dtype=np.int32) # (N, W)
1792 # To match old functionality - create indices as CONST
1793 tens_ser_list.append(
1794 testGen.ser.addConst(shape, dtype, indices_arr)
1795 )
1796
1797 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1798
1799 else:
1800 # ERROR_IF or floating point test
1801 # Use inclusive values upto index K for indices tensor
1802 data_range_list = (
1803 {"range": None},
1804 {"range": (0, K - 1)},
1805 {"range": None},
1806 )
1807 argsDict["data_range_list"] = data_range_list
1808
1809 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001810 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001811 )
1812
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001813
1814class TosaArgGen:
1815 """Argument generators create exhaustive or random lists of attributes for
1816 operators that take attributes or other parameters.
1817
1818 The return value is a list of (descriptive_name, [arglist]) tuples where
1819 the descriptive_name is appended to the test name and the arglist is expanded
1820 as arguments to the operator build function.
1821 """
1822
    def __init__(self):
        # No instance state - all argument generators are static methods
        pass
1825
1826 @staticmethod
evacha019c96eef2024-02-07 11:21:55 +00001827 def _add_data_generators(testGen, opName, shapeList, dtype, arg_list, error_name):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001828 """Add extra tests for each type of data generator for this op."""
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001829 if (
1830 error_name is None
1831 and "data_gen" in testGen.TOSA_OP_LIST[opName]
1832 and gtu.dtypeIsSupportedByCompliance(dtype)
1833 ):
evacha01ad8e1e22024-03-19 12:42:17 +00001834 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"].get(
1835 dtype, (gtu.DataGenType.PSEUDO_RANDOM,)
1836 )
1837
Jeremy Johnson1271c442023-09-05 11:39:26 +01001838 else:
1839 # Error test or No data generator types listed - assume random
1840 dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
1841
1842 # Expand arg list with other data generator types
1843 new_arg_list = []
1844 for dg_type in dataGenTypesList:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001845 for arg_str, args_dict in arg_list:
evacha01ad8e1e22024-03-19 12:42:17 +00001846 gen_args_dict = args_dict.copy()
Jeremy Johnson1271c442023-09-05 11:39:26 +01001847 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
Jeremy Johnson30476252023-11-20 16:15:30 +00001848 if error_name is None:
1849 num_test_sets = (
1850 args_dict["num_test_sets"]
1851 if "num_test_sets" in args_dict
1852 else 0
1853 )
1854 else:
evacha019c96eef2024-02-07 11:21:55 +00001855 # Add single test for pseudo random
Jeremy Johnson30476252023-11-20 16:15:30 +00001856 num_test_sets = 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01001857
1858 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
1859 # Extra tests for each dot product test set
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001860 dot_products = args_dict["dot_products"]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001861 if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001862 shape_info = (
1863 " ({})".format(testGen.shapeStr(args_dict["shape"]))
1864 if "shape" in args_dict
1865 else ""
1866 )
Jeremy Johnsonaf090182024-02-13 18:25:39 +00001867 logger.info(
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00001868 f"Skipping {opName}{shape_info} {gtu.DTYPE_ATTRIBUTES[dtype]['json']} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
Jeremy Johnson1271c442023-09-05 11:39:26 +01001869 )
1870 continue
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001871 # KS and acc_type is required by all dot product generators
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001872 assert "ks" in args_dict
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001873 assert "acc_type" in args_dict
Jeremy Johnson1271c442023-09-05 11:39:26 +01001874
Jeremy Johnson30476252023-11-20 16:15:30 +00001875 num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS
1876
evacha01ad8e1e22024-03-19 12:42:17 +00001877 elif dg_type == gtu.DataGenType.FULL_RANGE:
1878 tensor_size = gtu.product(shapeList[0])
1879 if tensor_size < gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]:
1880 shape_info = " ({})".format(shapeList[0])
1881 logger.info(
1882 f"Skipping {opName}{shape_info} as tensor data size too small for full range of values {tensor_size} < {gtu.DTYPE_ATTRIBUTES[dtype]['fullset']}"
1883 )
1884 continue
1885 # Large enough tensor data size for full range, add a single test
1886 num_test_sets = 0
1887 arg_str = f"{arg_str}_full" if arg_str else "full"
1888 gen_args_dict["tags"] = args_dict.get("tags", []) + [
1889 "non_finite_fp_data"
1890 ]
1891
1892 gen_args_dict["dg_type"] = dg_type
Jeremy Johnson30476252023-11-20 16:15:30 +00001893 if num_test_sets > 0:
1894 for s in range(0, num_test_sets):
evacha019c96eef2024-02-07 11:21:55 +00001895 set_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
evacha01ad8e1e22024-03-19 12:42:17 +00001896 set_args_dict = gen_args_dict.copy()
evacha019c96eef2024-02-07 11:21:55 +00001897 set_args_dict["s"] = s
evacha019c96eef2024-02-07 11:21:55 +00001898 new_arg_list.append((set_arg_str, set_args_dict))
Jeremy Johnson30476252023-11-20 16:15:30 +00001899 else:
1900 # Default is a single test
evacha01ad8e1e22024-03-19 12:42:17 +00001901 new_arg_list.append((arg_str, gen_args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001902
1903 return new_arg_list
1904
1905 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001906 def agNone(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001907 """A trivial argument generator for operators that don't take any
1908 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001909 arg_list = TosaArgGen._add_data_generators(
1910 testGen,
1911 opName,
evacha019c96eef2024-02-07 11:21:55 +00001912 shapeList,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001913 dtype,
1914 [("", {})],
1915 error_name,
1916 )
1917 # Return list of tuples: (arg_str, args_dict)
1918 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001919
1920 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001921 def agPow(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001922 """Pow operator needs different test sets to cover random numbers
1923 without creating NaNs or Infs"""
1924 arg_list = TosaArgGen._add_data_generators(
1925 testGen,
1926 opName,
evacha019c96eef2024-02-07 11:21:55 +00001927 shapeList,
Jeremy Johnson30476252023-11-20 16:15:30 +00001928 dtype,
1929 [("", {"num_test_sets": 3})],
1930 error_name,
1931 )
1932 # Return list of tuples: (arg_str, args_dict)
1933 return arg_list
1934
1935 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001936 def agAxis(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001937 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001938 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001939 shape = shapeList[0]
1940
1941 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001942 # Set too small axis
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001943 axes = [rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001944 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001945 # Set too large axis
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001946 axes = [rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001947 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001948 # Create tests for each dimension
1949 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001950
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001951 opid = testGen.TOSA_OP_LIST[opName]["op"]
1952
1953 for a in axes:
1954 args_dict = {"axis": int(a)}
1955 if opid == Op.REDUCE_SUM:
Jeremy Johnsone52c0a32024-03-11 09:58:24 +00001956 output_shape = shape.copy()
1957 if error_name is None:
1958 # It only matters that we calculate the dot_products correctly
1959 # for non error_if tests as they should never be run
1960 output_shape[a] = 1
1961 args_dict["dot_products"] = gtu.product(output_shape)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001962 args_dict["shape"] = shape
1963 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1964 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1965
1966 arg_list.append(("axis{}".format(a), args_dict))
1967
1968 arg_list = TosaArgGen._add_data_generators(
1969 testGen,
1970 opName,
evacha019c96eef2024-02-07 11:21:55 +00001971 shapeList,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001972 dtype,
1973 arg_list,
1974 error_name,
1975 )
1976 # Return list of tuples: (arg_str, args_dict)
1977 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001978
1979 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001980 def _calculate_sparsity(num_tests, sparsity_factor):
1981 sparsity = num_tests // sparsity_factor + 1
1982 # If there are only a small number of tests, just select them all
1983 if sparsity < 13:
1984 sparsity = 1
1985 # To get a variety of parameter combinations sparsity should not be a
1986 # multiple of 2, 3 or 5
1987 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1988 sparsity += 1
1989 return sparsity
1990
    # Maximum number of error_if variants to produce
    # NOTE(review): presumably consumed by the argument generators below to cap
    # otherwise exhaustive ERROR_IF combinations - verify usage
    MAX_TESTS_ERROR_IFS = 3
Jeremy Johnsondd975b82024-02-28 17:29:13 +00001993
    @staticmethod
    def agConv(testGen, rng, opName, shapeList, dtypes, error_name=None):
        """Create test arguments for CONV2D, CONV3D and DEPTHWISE_CONV2D.

        Builds a list of (arg_str, args_dict) tuples covering combinations of
        accumulator dtype, stride, padding and dilation - sparsely sampled
        when the combination count is large - then passes the list through
        TosaArgGen._add_data_generators.

        Args:
            testGen: test generator - supplies TOSA_OP_LIST, command line args
                and the level 8k limits
            rng: random number generator, used to pick invalid ErrorIf values
            opName: operator name, key into testGen.TOSA_OP_LIST
            shapeList: [ifm_shape, filter_shape, ...] for this test
            dtypes: tensor dtypes for this test
            error_name: optional ErrorIf case to create invalid arguments for

        Returns:
            list of (arg_str, args_dict) tuples
        """
        # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
        arg_list = []

        # Shape: Batches, (Depth), Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: (OFM channels), (KD), KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)

        if error_name == ErrorIf.WrongAccumulatorType:
            accum_dtypes = (
                [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
            )

        # For op type checks
        op = testGen.TOSA_OP_LIST[opName]

        # Check the rank
        rank = 5 if op["op"] == Op.CONV3D else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits channels
        k_rank = rank - 2
        k_pos = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
        k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
        # compliance size - KS
        k_size = gtu.product(k_shape)
        if not op["op"] == Op.DEPTHWISE_CONV2D:
            k_size *= ifm_shape[-1]

        def get_conv_output_info(p, s, d, fix_up_padding=False):
            # Work out remainders and output dimensions with an
            # option to adjust paddings to create a valid operation
            # NOTE: the nonlocal names below are only read, never rebound
            nonlocal ifm_shape, k_shape, error_name, k_rank
            if fix_up_padding:
                p = list(p)  # Make paddings editable
            outputs_no_stride = []
            remainders = []
            outputs = []
            for index in range(k_rank):
                pad_offset = index * 2
                fixed = False
                # Fix up pad values to produce valid conv2d
                while not fixed:
                    # Output dimension without being adjusted for stride
                    output_no_stride = (
                        ifm_shape[index + 1]
                        - 1
                        + p[pad_offset]
                        + p[pad_offset + 1]
                        - (k_shape[index] - 1) * d[index]
                    )
                    # Tensor left over after applying striding
                    remainder = output_no_stride % s[index]
                    if not fix_up_padding:
                        # Just want remainders and outputs
                        break
                    if output_no_stride <= 0:
                        # Grow the trailing padding until the output is positive
                        p[pad_offset + 1] += abs(output_no_stride) + 1
                        continue
                    if error_name == ErrorIf.ConvOutputShapeNonInteger:
                        if remainder:
                            # Conditions to trigger the test
                            fixed = True
                        else:
                            p[pad_offset + 1] += 1
                    else:
                        if remainder:
                            # Stride will be negative for StrideSmallerOne
                            assert remainder > 0 or (
                                error_name == ErrorIf.StrideSmallerOne and remainder < 0
                            )
                            p[pad_offset + 1] += abs(remainder)
                        else:
                            fixed = True
                outputs_no_stride.append(output_no_stride)
                remainders.append(remainder)
                # Output dimension taking in to account stride
                outputs.append((output_no_stride // s[index]) + 1)

            if fix_up_padding:
                p = tuple(p)  # Make the paddings read-only
                assert min(outputs_no_stride) > 0, "Fix up did not work!"
            return p, remainders, outputs, outputs_no_stride

        # Only fix up padding for conv2d and float types currently
        fix_up_padding = gtu.dtypeIsFloat(dtypes[0]) and op["op"] == Op.CONV2D
        # Allow any size of output dimension
        max_dim_size = None
        # Include all tests by default
        sparsity = 1

        # Work out padding, strides and dilation ranges depending on
        # error and arguments
        if error_name in (
            ErrorIf.PadSmallerZero,
            ErrorIf.StrideSmallerOne,
            ErrorIf.DilationSmallerOne,
        ):
            # Use specific invalid value(s)
            if error_name == ErrorIf.PadSmallerZero:
                # Create negative paddings but with positive opposite paddings
                neg_pad = rng.choice(range(-5, 0))
                p_vals = [neg_pad, abs(neg_pad)]
            else:
                p_vals = [0, 0]
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [rng.choice(range(-5, 0))]
            else:
                s_vals = [1]
            if error_name == ErrorIf.DilationSmallerOne:
                d_vals = [rng.choice(range(-5, 1))]
            else:
                d_vals = [1]
            paddings = {tuple(p_vals) * k_rank}
            strides = {tuple(s_vals) * k_rank}
            dilations = {tuple(d_vals) * k_rank}

            fix_up_padding = True  # Need to fix up paddings to be valid

        elif testGen.args.level8k and error_name is None:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            dilation_shape = [1] * k_rank
            pad_shape = [0] * k_rank * 2
            if op["op"] == Op.CONV3D:
                # Small stride apart from for big kernel (see below) to keep
                # tensor size/calculation small
                stride_shape = [1] * k_rank
                for idx in range(k_rank):
                    pad_offset = idx * 2
                    if k_shape[idx] == bigKernel:
                        # Padding shape needs to account for tensor shape
                        pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                        pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                        # Big stride to reduce output size
                        stride_shape[idx] = bigKernel
                    else:
                        # Account for kernel size
                        pad_shape[pad_offset] = k_shape[idx] - 1
            else:
                # Always have a large stride with extra padding and dilation to keep
                # tensor calculation reasonable
                stride_shape = [bigKernel] * k_rank
                for idx in range(k_rank):
                    # Dilation shape must account for kernel size
                    dilation_shape[idx] = bigKernel // k_shape[idx]
                    # Padding shape needs to accommodate tensor/kernel & dilation
                    pad_offset = idx * 2
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

            strides = {tuple(stride_shape)}
            dilations = {tuple(dilation_shape)}
            paddings = {tuple(pad_shape)}
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

            # Currently allow all combinations that are reasonable size
            sparsity = 1
        else:
            # Generate comprehensive argument lists
            p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
            # Stride must be greater than 1 to force non-integer error
            startStride = 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
            s_vals = [x for x in range(startStride, testGen.args.max_conv_stride + 1)]
            d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]

            strides = {x for x in itertools.product(*([s_vals] * k_rank))}
            dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

            if error_name is None and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[0, bigPadding]] * (k_rank * 2))
                            )
                        }
                    )
                bigStride = 8
                strides.update(
                    {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
                )
                bigDilation = 7
                dilations.update(
                    {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
                )

            if error_name is None:
                # There are too many parameter combinations, so generate them sparsely,
                sparsity_factor = 120
                sparsity = TosaArgGen._calculate_sparsity(
                    len(paddings) * len(strides) * len(dilations), sparsity_factor
                )

        # Run through all the argument options creating valid test cases
        more_tests = True
        n = 0
        for a in accum_dtypes:
            for s in sorted(list(strides)):
                for p in sorted(list(paddings)):
                    for d in sorted(list(dilations)):
                        if more_tests and (n % sparsity == 0):
                            (
                                p,
                                remainders,
                                outputs,
                                outputs_no_stride,
                            ) = get_conv_output_info(p, s, d, fix_up_padding)
                            # Following is like checking each dimension N:
                            # (ifm_shape[N+1] - 1 + p[N*2] + p[N*2+1]) > d[N] * (k_shape[N] - 1)
                            if min(outputs_no_stride) <= 0:
                                # Not a valid operation
                                n += 1  # Increment count of tests
                                continue

                            if (
                                # the parameters must produce integer exact output
                                error_name != ErrorIf.ConvOutputShapeNonInteger
                                and max(remainders) == 0
                            ) or (
                                error_name == ErrorIf.ConvOutputShapeNonInteger
                                and max(remainders) > 0
                            ):
                                if (
                                    max_dim_size is not None
                                    and max(outputs) >= max_dim_size
                                ):
                                    # Test will consume too much memory - skip it
                                    logger.debug(
                                        "agConv: Convolution output too big - skipped"
                                    )
                                    continue

                                # Compliance - number of dot product calculations
                                if op["op"] == Op.DEPTHWISE_CONV2D:
                                    # N*OH*OW*C*M
                                    dots = gtu.product(
                                        (ifm_shape[0], *outputs, *filter_shape[2:])
                                    )
                                else:
                                    # N*OH*OW*OC or N*OD*OH*OW*OC
                                    dots = gtu.product(
                                        (ifm_shape[0], *outputs, filter_shape[0])
                                    )
                                args_dict = {
                                    "acc_type": a,
                                    "stride": s,
                                    "pad": p,
                                    "dilation": d,
                                    "kernel": k_shape,
                                    "ks": k_size,
                                    "dot_products": dots,
                                    "shape": ifm_shape,
                                }

                                # Support for larger values than 9 needs different delimiter
                                delim = "" if max(s + p + d) <= 9 else "x"
                                arg_list.append(
                                    (
                                        "acc{}_st{}_pad{}_dilat{}".format(
                                            testGen.typeStr(a),
                                            delim.join([str(x) for x in s]),
                                            delim.join([str(x) for x in p]),
                                            delim.join([str(x) for x in d]),
                                        ),
                                        args_dict,
                                    )
                                )
                        if (
                            error_name
                            and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS
                        ):
                            # Found enough errors
                            logger.debug(
                                f"Skipping creating more conv error tests for {error_name}"
                            )
                            more_tests = False
                        n += 1

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            shapeList,
            dtypes[0],
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2298
2299 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002300 def agFullyConnected(testGen, rng, opName, shapeList, dtypes, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002301
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002302 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002303 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01002304
2305 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002306 accum_dtype = gtu.get_wrong_output_type(opName, rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01002307 elif error_name == ErrorIf.WrongInputType:
2308 # Pick some potentially correct output dtype if input type is incorrect
2309 accum_dtype = DType.INT32
2310 else:
Tai Lyf36f2562024-03-14 16:21:29 +00002311 accum_dtype = dtypes[-1] # use output dtype as accum_dtype
James Ward8b390432022-08-12 20:48:56 +01002312
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002313 # Set up compliance info
2314 args_dict = {
2315 "acc_type": accum_dtype,
2316 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
2317 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
2318 "shape": shapeList[0],
2319 }
2320
2321 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
2322
2323 arg_list = TosaArgGen._add_data_generators(
2324 testGen,
2325 opName,
evacha019c96eef2024-02-07 11:21:55 +00002326 shapeList,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002327 input_dtype,
2328 arg_list,
2329 error_name,
2330 )
2331 # Return list of tuples: (arg_str, args_dict)
2332 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002333
2334 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002335 def agMatMul(testGen, rng, opName, shapeList, dtype, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002336 # Get valid accumulate type(s)
2337 if dtype == DType.INT8:
2338 accum_dtypes = [DType.INT32]
2339 elif dtype == DType.INT16:
2340 accum_dtypes = [DType.INT48]
2341 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002342 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002343 elif dtype == DType.BF16:
2344 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002345 elif dtype == DType.FP32:
2346 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002347 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2348 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002349 elif error_name is None:
2350 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
2351
2352 if error_name == ErrorIf.WrongOutputType:
2353 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002354 accum_dtypes = [gtu.get_wrong_output_type(opName, rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01002355 elif error_name == ErrorIf.WrongInputType:
2356 # Pick some potentially correct output dtype if input type is incorrect
2357 accum_dtypes = [DType.INT32]
2358
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002359 # Set up compliance info
2360 args_dict = {
2361 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
2362 # Set dot_products = N*H*W
2363 "dot_products": gtu.product(
2364 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
2365 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002366 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002367 }
2368
2369 # Create arg tuple of string and dict
2370 arg_list = []
2371 for a in accum_dtypes:
2372 d = args_dict.copy()
2373 d["acc_type"] = a
2374 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01002375
2376 arg_list = TosaArgGen._add_data_generators(
2377 testGen,
2378 opName,
evacha019c96eef2024-02-07 11:21:55 +00002379 shapeList,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002380 dtype,
2381 arg_list,
2382 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002383 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002384 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01002385 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002386
2387 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002388 def agTransposeConv2D(testGen, rng, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002389 arg_list = []
2390
Jeremy Johnson0c716862023-04-13 17:18:19 +01002391 if testGen.args.level8k and error_name is not None:
2392 # Don't produce negative large tests
2393 return arg_list
2394
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002395 ifm_shape = shapeList[0]
2396 filter_shape = shapeList[1]
2397
Tai Lyf36f2562024-03-14 16:21:29 +00002398 accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
2399
2400 if error_name == ErrorIf.WrongAccumulatorType:
2401 accum_dtypes = (
2402 [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
2403 )
James Ward8b390432022-08-12 20:48:56 +01002404
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002405 # Must be rank 4
2406 if error_name != ErrorIf.WrongRank:
2407 assert len(ifm_shape) == 4
2408 assert len(filter_shape) == 4
2409
Jeremy Johnson0c716862023-04-13 17:18:19 +01002410 k_shape = tuple(filter_shape[1:3])
Jeremy Johnson95a67102024-01-10 14:16:39 +00002411 # compliance size - KS
2412 k_size = gtu.product((*k_shape, ifm_shape[3]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002413
Jeremy Johnson0c716862023-04-13 17:18:19 +01002414 if not testGen.args.level8k:
2415 # Generate comprehensive argument lists
2416 # - except for named errors, which use specific invalid value(s)
2417 smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
2418 if error_name == ErrorIf.PadLargerEqualKernel:
2419 max_filter_size = -max(k_shape[0], k_shape[1])
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002420 p_vals = [rng.choice(range(max_filter_size - 10, max_filter_size))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002421 else:
2422 p_vals = [
2423 x
2424 for x in range(
2425 smallest_padding_size, testGen.args.max_conv_padding + 1
2426 )
2427 ]
2428 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2429 if error_name == ErrorIf.StrideSmallerOne:
2430 # Can't use stride=0, as it is used to derive output shape, as a divisor
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002431 s_vals = [rng.choice(range(-5, 0))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002432 else:
2433 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
2434 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002435
Jeremy Johnson0c716862023-04-13 17:18:19 +01002436 if not error_name and testGen.args.oversize:
2437 # add some oversize argument values
2438 if max(ifm_shape) < 64:
2439 bigPadding = 9
2440 paddings.update(
2441 {
2442 x
2443 for x in itertools.product(
2444 *([[smallest_padding_size, bigPadding]] * 4)
2445 )
2446 }
2447 )
2448 bigStride = 8
2449 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
2450
2451 # There are too many parameter combinations, so generate them sparsely,
2452 # very sparse for negative tests
2453 sparsity_factor = 2 if error_name else 10
2454 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
2455 # If there are only a small number of tests, just select them all
2456 if sparsity < 13:
2457 sparsity = 1
2458 # To get a variety of parameter combinations sparsity should not be a
2459 # multiple of 2, 3 or 5
2460 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
2461 sparsity += 1
2462 else:
2463 # Only test 8k levels boundaries
2464 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2465 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2466 bigPadding = bigKernel
2467
2468 pad_shape = [0] * (len(k_shape) * 2)
2469 stride_shape = [1] * len(k_shape)
2470 # The point at which input dimension combined with the stride will
2471 # create large output sizes!
2472 LARGE_SIZE = 2
2473 for idx in range(len(k_shape)):
2474 pad_offset = idx * 2
2475 if k_shape[idx] == bigKernel:
2476 # Set large stride
2477 stride_shape[idx] = bigKernel
2478 # Use negative output padding to reduce shape size
2479 pad_shape[pad_offset] = -(bigPadding - 1)
2480 if ifm_shape[idx + 1] > LARGE_SIZE:
2481 pad_shape[pad_offset + 1] = -(bigPadding - 1)
2482 else:
2483 # The other dimension should be the bigKernel
2484 alt_idx = 1 - idx
2485 if (
2486 k_shape[alt_idx] == bigKernel
2487 and ifm_shape[alt_idx + 1] < LARGE_SIZE
2488 ):
2489 # As the input is small, the large stride won't
2490 # affect the output so we can add some padding
2491 pad_shape[pad_offset + 1] = bigPadding
2492
2493 strides = {tuple(stride_shape)}
2494 paddings = {tuple(pad_shape)}
2495
2496 # Currently allow all combinations that are reasonable size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002497 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002498
2499 n = 0
Tai Lyf36f2562024-03-14 16:21:29 +00002500 for a in accum_dtypes:
2501 for s in sorted(list(strides)):
2502 for p in sorted(list(paddings)):
2503 if n % sparsity == 0:
2504 # Determine the output shape
2505 oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
2506 ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
2507 os = [ifm_shape[0], oh, ow, filter_shape[0]]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002508
Tai Lyf36f2562024-03-14 16:21:29 +00002509 # N*OH*OW*OC
2510 dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
2511 args_dict = {
2512 "acc_type": a,
2513 "stride": s,
2514 "pad": p,
2515 "kernel": k_shape,
2516 "ks": k_size,
2517 "dot_products": dots,
2518 "shape": ifm_shape,
2519 "out_shape": os,
2520 }
Jeremy Johnson95a67102024-01-10 14:16:39 +00002521
Tai Lyf36f2562024-03-14 16:21:29 +00002522 # Support for larger values than 9 needs different delimiter
2523 delim = "" if max(s + p) <= 9 else "x"
2524 arg_list.append(
2525 (
2526 "acc{}_st{}_pad{}_os{}".format(
2527 testGen.typeStr(a),
2528 delim.join([str(x) for x in s]),
2529 delim.join([str(x) for x in p]),
2530 "x".join([str(x) for x in os]),
2531 ),
2532 args_dict,
2533 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002534 )
Tai Lyf36f2562024-03-14 16:21:29 +00002535 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002536
Jeremy Johnson95a67102024-01-10 14:16:39 +00002537 arg_list = TosaArgGen._add_data_generators(
2538 testGen,
2539 opName,
evacha019c96eef2024-02-07 11:21:55 +00002540 shapeList,
Jeremy Johnson95a67102024-01-10 14:16:39 +00002541 dtypes[0],
2542 arg_list,
2543 error_name,
2544 )
2545 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002546 return arg_list
2547
    @staticmethod
    def agPad(testGen, rng, opName, shapeList, dtype, error_name=None):
        """Create test arguments for the PAD operator.

        Builds (arg_str, args_dict) tuples for combinations of per-dimension
        (before, after) padding values plus randomized pad constants,
        sparsely sampled when the combination count is large, then passes
        the list through TosaArgGen._add_data_generators.
        """
        rank = len(shapeList[0])

        if error_name is None and testGen.args.oversize:
            pad_values = [6, 7, 10, 13]
        elif error_name == ErrorIf.PadSmallerZero:
            pad_values = [x for x in range(-2, 0)]
        else:
            # Exhaustively test combinations of padding on each side of each dimension
            # - the range of padding values is defined by pad_min and pad_max
            pad_min, pad_max = 0, 1
            pad_values = [x for x in range(pad_min, pad_max + 1)]

        # Calculate pad combinations
        # axis_pad_values: all (before, after) pairs for one dimension
        # shape_pad_values: one such pair per dimension of the shape
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        # Choose the pad constants - random for the matching dtype family,
        # zero otherwise
        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = rng.randNumberDType(dtype)
            pad_const_fp = 0
        elif gtu.dtypeIsFloat(dtype):
            pad_const_int = 0
            pad_const_fp = rng.randNumberDType(dtype)
        else:
            assert error_name == ErrorIf.WrongInputType
            pad_const_int = 0
            pad_const_fp = 0

        list_shape_pad_values = list(shape_pad_values)
        # If we are producing tests for rank 6 or greater use sparsity
        if len(list_shape_pad_values) > 1024:
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(list_shape_pad_values), sparsity_factor
            )
        else:
            sparsity = 1

        # Build arg list
        arg_list = []
        for n, paddings in enumerate(list_shape_pad_values):
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        # Would produce an invalid output shape - zero this
                        # dimension's padding instead
                        paddings[i] = (0, 0)
                    # Combination is only kept if every dimension still has
                    # at least one negative padding value
                    if all([p > -1 for p in paddings[i]]):
                        args_valid = False
            if args_valid and n % sparsity == 0:
                # Work out name
                pad_list = []
                for r in range(rank):
                    pad_list.extend(paddings[r])

                delim = "" if max(pad_list) <= 9 else "x"
                name = "pad{}".format(delim.join([str(x) for x in pad_list]))

                args_dict = {
                    "pad": np.array(paddings),
                    "pad_const_int": pad_const_int,
                    "pad_const_fp": pad_const_fp,
                }
                arg_list.append((name, args_dict))

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            logger.debug(
                f"agPad: No PadSmallerZero ErrorIf test created for input shape: {shapeList[0]}"
            )

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            shapeList,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2635
2636 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002637 def agPooling(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002638 arg_list = []
2639
2640 shape = shapeList[0]
2641 if error_name != ErrorIf.WrongRank:
2642 assert len(shape) == 4
2643
Jeremy Johnson0c716862023-04-13 17:18:19 +01002644 test_level8k = testGen.args.level8k and error_name is None
2645
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002646 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002647 startKernel = 2
2648 startPad = 0
2649 if not test_level8k:
2650 # Generate comprehensive argument lists
2651 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2652 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2653 # Stride must be greater than 1 to force non-integer error
2654 s_vals = [
2655 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2656 ]
2657 strides = {x for x in itertools.product(*([s_vals] * 2))}
2658 k_vals = [
2659 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2660 ]
2661 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2662 max_dim_size = None
2663 else:
2664 # Only test 8k levels
2665 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2666 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2667 strides = {(1, bigStride), (bigStride, 4)}
2668 kernels = {(1, bigKernel), (bigKernel, 3)}
2669 paddings = set()
2670 for s in sorted(list(strides)):
2671 for k in sorted(list(kernels)):
2672 padding = []
2673 for idx in range(len(k)):
2674 total_padding = s[idx] - shape[idx + 1] + k[idx]
2675 while total_padding < 0:
2676 # Must meet: shape + padding > kernel
2677 total_padding += s[idx]
2678 if total_padding < k[idx]:
2679 padding.extend([0, total_padding])
2680 else:
2681 # Note this may produce padding >= k[idx] which is not
2682 # allowed - but will be ignored in the creation loop below
2683 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2684 paddings.add(tuple(padding))
2685 # Create a limit for the output dimensions size
2686 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002687
James Ward8b390432022-08-12 20:48:56 +01002688 if opName == "max_pool2d":
2689 accum_dtypes = [None] # max_pool has no accumulate dtype
2690 elif dtype == DType.INT8 or dtype == DType.INT16:
2691 accum_dtypes = [DType.INT32]
2692 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002693 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002694 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002695 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002696 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2697 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002698 elif error_name is None:
2699 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2700 else:
2701 # Set to something for the ErrorIf case which has
2702 # incorrect input data-type
2703 accum_dtypes = [DType.INT32]
2704
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00002705 if error_name == ErrorIf.WrongAccumulatorType:
2706 accum_dtypes = list(gtu.usableDTypes(excludes=accum_dtypes))
2707
Jeremy Johnson0c716862023-04-13 17:18:19 +01002708 if not test_level8k:
2709 if testGen.args.oversize:
2710 # add some oversize argument values
2711 bigStride = 7
2712 bigKernel = 9
2713 strides.update(
2714 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002715 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002716 kernels.update(
2717 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2718 )
2719 if max(shape) < 64:
2720 # padding must be less than the kernel size
2721 bigPadding = bigKernel - 1
2722 paddings.update(
2723 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2724 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002725
Jeremy Johnson87460262024-03-25 09:46:02 +00002726 if error_name:
2727 # Cycle through all error_if tests but we only keep the first few
2728 sparsity = 1
2729 else:
2730 # There are too many parameter combinations, so generate them sparsely
2731 sparsity_factor = 500
2732 sparsity = (
2733 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2734 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002735 else:
2736 # We have already limited test output combinations for 8k tests
2737 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002738
James Ward8b390432022-08-12 20:48:56 +01002739 arg_str = (
2740 "acc{}_st{}_kern{}_pad{}"
2741 if accum_dtypes[0] is not None
2742 else "st{}_kern{}_pad{}"
2743 )
2744
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002745 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01002746 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002747 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01002748
2749 # Support for larger values than 9 needs different delimiter
2750 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01002751 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01002752 delim.join([str(x) for x in stride]),
2753 delim.join([str(x) for x in kern]),
2754 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01002755 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002756 args_dict = {
2757 "stride": stride,
2758 "pad": pad,
2759 "kernel": kern,
2760 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002761 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002762 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
2763 }
James Ward8b390432022-08-12 20:48:56 +01002764
2765 if accum is not None:
2766 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002767 args_dict["acc_type"] = accum
2768 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002769
Jeremy Johnson87460262024-03-25 09:46:02 +00002770 more_tests = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002771 n = 0
James Ward8b390432022-08-12 20:48:56 +01002772 for a in accum_dtypes:
2773 for s in sorted(list(strides)):
2774 for p in sorted(list(paddings)):
2775 for k in sorted(list(kernels)):
2776 if error_name in [
2777 ErrorIf.StrideSmallerOne,
2778 ErrorIf.KernelSmallerOne,
2779 ErrorIf.PadSmallerZero,
2780 ErrorIf.PadLargerEqualKernel,
2781 ]:
2782 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002783 rng, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002784 )
James Ward8b390432022-08-12 20:48:56 +01002785 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002786 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002787 get_arg_list_element(a, sNew, pNew, kNew, shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002788 )
James Ward8b390432022-08-12 20:48:56 +01002789 elif (
Jeremy Johnson87460262024-03-25 09:46:02 +00002790 more_tests
2791 and n % sparsity == 0
James Ward8b390432022-08-12 20:48:56 +01002792 # padding must not exceed the kernel size
2793 and p[0] < k[0]
2794 and p[1] < k[0]
2795 and p[2] < k[1]
2796 and p[3] < k[1]
2797 # the padded shape must exceed the kernel size
2798 and (shape[1] + p[0] + p[1]) > k[0]
2799 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002800 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002801 partial_h = shape[1] + p[0] + p[1] - k[0]
2802 partial_w = shape[2] + p[2] + p[3] - k[1]
2803 remainder_h = partial_h % s[0]
2804 remainder_w = partial_w % s[1]
2805 output_h = partial_h // s[0] + 1
2806 output_w = partial_w // s[1] + 1
Jeremy Johnsonaf090182024-02-13 18:25:39 +00002807 logger.debug(
2808 f"agPooling: {shape} remainder=({remainder_h}, {remainder_w}) output=({output_h}, {output_w})"
2809 )
James Ward8b390432022-08-12 20:48:56 +01002810 if (
2811 # the parameters must produce integer exact output
2812 error_name != ErrorIf.PoolingOutputShapeNonInteger
2813 and remainder_h == 0
2814 and remainder_w == 0
2815 ) or (
2816 error_name == ErrorIf.PoolingOutputShapeNonInteger
2817 and (remainder_h != 0 or remainder_w != 0)
2818 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002819 if (
2820 max_dim_size is not None
2821 and max(output_h, output_w) > max_dim_size
2822 ):
2823 # Test will consume too much memory - skip it
2824 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002825 # Dot products = N*OH*OW*C
2826 dp = gtu.product(
2827 (shape[0], output_h, output_w, shape[3])
2828 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002829 arg_list.append(
2830 get_arg_list_element(a, s, p, k, dp, shape)
2831 )
Jeremy Johnson87460262024-03-25 09:46:02 +00002832 if (
2833 error_name
2834 and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS
2835 ):
2836 # Found enough errors
2837 logger.debug(
2838 f"Skipping creating more pooling error tests for {error_name}"
2839 )
2840 more_tests = False
2841
James Ward8b390432022-08-12 20:48:56 +01002842 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002843
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002844 # Now add data generator types
2845 arg_list = TosaArgGen._add_data_generators(
2846 testGen,
2847 opName,
evacha019c96eef2024-02-07 11:21:55 +00002848 shapeList,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002849 dtype,
2850 arg_list,
2851 error_name,
2852 )
2853
2854 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002855 return arg_list
2856
2857 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002858 def agCast(testGen, rng, opName, shapeList, inDtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002859 arg_list = []
2860
2861 # Enumerate the output types here
2862 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002863 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(inDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002864 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00002865 dtypeList = [
2866 DType.BOOL,
2867 DType.INT16,
2868 DType.INT32,
2869 DType.FP16,
2870 DType.BF16,
2871 DType.FP32,
2872 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002873 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00002874 dtypeList = [
2875 DType.BOOL,
2876 DType.INT8,
2877 DType.INT32,
2878 DType.FP16,
2879 DType.BF16,
2880 DType.FP32,
2881 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002882 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00002883 dtypeList = [
2884 DType.BOOL,
2885 DType.INT8,
2886 DType.INT16,
2887 DType.FP16,
2888 DType.BF16,
2889 DType.FP32,
2890 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002891 elif inDtype == DType.BOOL:
2892 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01002893 elif inDtype == DType.FP16:
Won Jeon2c34b462024-02-06 18:37:00 +00002894 dtypeList = [
2895 DType.INT8,
2896 DType.INT16,
2897 DType.INT32,
2898 DType.FP32,
2899 DType.FP8E4M3,
2900 DType.FP8E5M2,
2901 ]
James Ward24dbc422022-10-19 12:20:31 +01002902 elif inDtype == DType.BF16:
Won Jeon2c34b462024-02-06 18:37:00 +00002903 dtypeList = [
2904 DType.INT8,
2905 DType.INT16,
2906 DType.INT32,
2907 DType.FP32,
2908 DType.FP8E4M3,
2909 DType.FP8E5M2,
2910 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002911 elif inDtype == DType.FP32:
Won Jeon2c34b462024-02-06 18:37:00 +00002912 dtypeList = [
2913 DType.INT8,
2914 DType.INT16,
2915 DType.INT32,
2916 DType.FP16,
2917 DType.BF16,
2918 DType.FP8E4M3,
2919 DType.FP8E5M2,
2920 ]
2921 elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]:
2922 dtypeList = [DType.FP16, DType.BF16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002923 elif error_name == ErrorIf.WrongInputType:
2924 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002925 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002926 else:
2927 raise Exception("Unexpected input dtype: {}".format(inDtype))
2928
2929 for dtype in dtypeList:
Jeremy Johnson708da822023-11-15 16:25:45 +00002930 arg_list.append(
2931 ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
2932 )
2933
2934 # Now add data generator types
2935 arg_list = TosaArgGen._add_data_generators(
2936 testGen,
2937 opName,
evacha019c96eef2024-02-07 11:21:55 +00002938 shapeList,
Jeremy Johnson708da822023-11-15 16:25:45 +00002939 dtype,
2940 arg_list,
2941 error_name,
2942 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002943
2944 return arg_list
2945
    @staticmethod
    def agRescale(testGen, rng, opName, shapeList, inDtype, error_name=None):
        """Generate the argument combinations for RESCALE tests.

        Enumerates output dtype x scale32 x double_round x per_channel,
        skipping combinations that are illegal for the given input dtype or
        that do not match the requested ERROR_IF test.  For each kept
        combination a random scale array is generated and converted into
        per-channel multiplier/shift pairs.

        Args:
            testGen: test generator instance (provides typeStr)
            rng: random number generator
            opName: operator name, forwarded to the data generator stage
            shapeList: list of input shapes; shapeList[0][-1] gives the
                channel count used for per-channel tests
            inDtype: input data type of the RESCALE
            error_name: optional ErrorIf name to build a negative test for

        Returns:
            list of tuples: (arg_str, args_dict)
        """
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                # These output types may carry a non-zero zero point anyway
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                # Not a suitable wrong-output-type pairing for this ERROR_IF
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition.  Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        # Per-channel tests need one multiplier/shift per
                        # channel (the last dimension of the input shape)
                        if per_channel:
                            nc = shapeList[0][-1]
                        else:
                            nc = 1

                        in_type_width = gtu.dtypeWidth(inDtype)
                        out_type_width = gtu.dtypeWidth(outDtype)

                        # Calculate scale based on:
                        # scale = a * (2^output_width) / (2^input_width)

                        a = np.float32(rng.random(size=[nc]))
                        scale_arr = a * np.float32(
                            (1 << out_type_width) / (1 << in_type_width)
                        )

                        if scale32:
                            # Cap the scaling at 2^31 - 1 for scale32
                            scale_arr = np.clip(
                                scale_arr, 1.0 / (1 << 31), (1 << 31) - 1
                            )
                        else:
                            # Cap the scaling at 2^15 - 1 for scale16
                            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

                        logger.debug(
                            f"agRescale: {out_type_width} {in_type_width} -> {scale_arr}"
                        )

                        # Convert each float scale to an integer
                        # multiplier/shift pair
                        multiplier_arr = np.int32(np.zeros(shape=[nc]))
                        shift_arr = np.int32(np.zeros(shape=[nc]))
                        for i in range(nc):
                            (
                                multiplier_arr[i],
                                shift_arr[i],
                            ) = TosaQuantGen.computeMultiplierAndShift(
                                scale_arr[i], scale32
                            )

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                {
                                    "output_dtype": outDtype,
                                    "scale": scale32,
                                    "double_round": double_round,
                                    "per_channel": per_channel,
                                    "multiplier": multiplier_arr,
                                    "shift": shift_arr,
                                },
                            )
                        )

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            shapeList,
            inDtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
3099
3100 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003101 def agMul(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003102 arg_list = []
3103
3104 if dtype is DType.INT32:
3105 for p in range(testGen.args.num_rand_permutations):
3106
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003107 shift = rng.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003108 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003109 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003110 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003111
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003112 arg_list = TosaArgGen._add_data_generators(
3113 testGen,
3114 opName,
evacha019c96eef2024-02-07 11:21:55 +00003115 shapeList,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003116 dtype,
3117 arg_list,
3118 error_name,
3119 )
3120 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003121 return arg_list
3122
3123 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003124 def agArithmeticRightShift(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003125 arg_list = []
3126
Jeremy Johnson587cc842024-02-08 11:45:44 +00003127 for round in (True, False):
3128 args_dict = {
3129 "round": round,
3130 }
3131 arg_list.append((f"round{round}", args_dict))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003132
Jeremy Johnson587cc842024-02-08 11:45:44 +00003133 arg_list = TosaArgGen._add_data_generators(
3134 testGen,
3135 opName,
evacha019c96eef2024-02-07 11:21:55 +00003136 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003137 dtype,
3138 arg_list,
3139 error_name,
3140 )
3141 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003142 return arg_list
3143
Luke Hutton57287132023-02-06 14:54:18 +00003144 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003145 def agFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Luke Hutton57287132023-02-06 14:54:18 +00003146 arg_list = []
3147
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003148 shape = shapeList[0]
3149 dot_products = gtu.product(shape)
3150 ks = 2 * shape[1] * shape[2] # 2*H*W
3151 for inverse in (True, False):
3152 args_dict = {
3153 "dot_products": dot_products,
3154 "shape": shape,
3155 "ks": ks,
3156 "acc_type": dtype,
3157 "inverse": inverse,
3158 }
3159 arg_list.append((f"inverse{inverse}", args_dict))
Luke Hutton57287132023-02-06 14:54:18 +00003160
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003161 arg_list = TosaArgGen._add_data_generators(
3162 testGen,
3163 opName,
evacha019c96eef2024-02-07 11:21:55 +00003164 shapeList,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003165 dtype,
3166 arg_list,
3167 error_name,
3168 )
3169 # Return list of tuples: (arg_str, args_dict)
Luke Hutton57287132023-02-06 14:54:18 +00003170 return arg_list
3171
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003172 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003173 def agRFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003174 arg_list = []
3175
3176 shape = shapeList[0]
3177 dot_products = gtu.product(shape)
3178 ks = shape[1] * shape[2] # H*W
3179 args_dict = {
3180 "dot_products": dot_products,
3181 "shape": shape,
3182 "ks": ks,
3183 "acc_type": dtype,
3184 }
3185 arg_list.append(("", args_dict))
3186
3187 arg_list = TosaArgGen._add_data_generators(
3188 testGen,
3189 opName,
evacha019c96eef2024-02-07 11:21:55 +00003190 shapeList,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003191 dtype,
3192 arg_list,
3193 error_name,
3194 )
3195 # Return list of tuples: (arg_str, args_dict)
3196 return arg_list
3197
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003198 # Helper function for reshape. Gets some factors of a larger number.
3199 @staticmethod
3200 def getFactors(val, start=1):
3201 factors = []
3202
3203 for i in range(start, int(np.sqrt(val)) + 1):
3204 if (val % i) == 0:
3205 factors.append(i)
3206
3207 return factors
3208
    @staticmethod
    def agReshape(testGen, rng, opName, shapeList, dtype, error_name=None):
        """Generate the argument combinations for RESHAPE tests.

        For up to num_rand_permutations attempts, picks a random new rank
        and constructs a random new shape with the same total element
        count as the input shape, built from random factors.  Duplicate
        shapes are discarded.

        Returns:
            list of tuples: (arg_str, args_dict) where args_dict holds
            "new_shape"
        """
        arg_list = []

        origShape = shapeList[0]
        totalElements = gtu.product(origShape)
        # Candidate dimension sizes: factors (up to the square root) of the
        # total element count
        factors = TosaArgGen.getFactors(totalElements)

        # Find new shapes up to the number of permutations asked for
        # This code is NOT fast. Fortunately, the numbers are fairly small.
        for p in range(testGen.args.num_rand_permutations):
            # Rank from 1 to MAX_TENSOR_RANK
            newRank = rng.randInt(1, (gtu.MAX_TENSOR_RANK + 1))
            if len(factors) < newRank:
                # Not enough factors to fill a shape of this rank
                continue

            # escape_counter limits the generation of new shapes to a reasonable time
            for escape_counter in range(100):

                # Generate the new shape of the chosen new rank
                newShape = []
                remainingElements = totalElements
                shuffledFactors = rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # Whatever element count remains becomes the last dimension
                newShape.append(remainingElements)

                # Check for duplicates
                duplicate = False
                for name, args_dict in arg_list:
                    if args_dict["new_shape"] == newShape:
                        duplicate = True
                        break

                if not duplicate:
                    outShape = "x".join([str(x) for x in newShape])
                    arg_list.append(
                        (
                            "perm{}_rank{}_out{}".format(p, newRank, outShape),
                            {"new_shape": newShape},
                        )
                    )
                    # Found an output shape for this permutation
                    break

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            shapeList,
            dtype,
            arg_list,
            error_name,
        )

        return arg_list
3270
    @staticmethod
    def agTranspose(testGen, rng, opName, shapeList, dtype, error_name=None):
        """Generate the argument combinations for TRANSPOSE tests.

        Builds candidate axis permutations (deliberately invalid ones for
        the IndexOutsideBounds / IndexUsedTwice ERROR_IF tests), shuffles
        them, and keeps up to num_rand_permutations of them.

        Returns:
            list of tuples: (arg_str, args_dict) where args_dict holds
            "perms"
        """
        arg_list = []

        ifm_shape = shapeList[0]

        if error_name == ErrorIf.IndexOutsideBounds:
            # Indices entirely outside [0, rank): too large or negative
            incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
            incorrect_small_index = range(-len(ifm_shape), 0)
            permutations = [p for p in itertools.permutations(incorrect_large_index)]
            permutations.extend(
                [p for p in itertools.permutations(incorrect_small_index)]
            )
        elif error_name == ErrorIf.IndexUsedTwice:
            # Create list with a duplicated index
            perm_range = list(range(len(ifm_shape)))
            index_choice = rng.choice(range(len(perm_range)))
            perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
            permutations = [p for p in itertools.permutations(perm_range)]

        else:
            # Get all permutations
            permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Get random permutation generator that uses all permutations
        random_permutations = rng.permutation(permutations)

        # Create list of required amount of permutations
        arg_list = [
            ("perm{}".format(p), {"perms": random_permutations[p].tolist()})
            for p in range(limit)
        ]
        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            shapeList,
            dtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
3317
    @staticmethod
    def agSlice(testGen, rng, opName, shapeList, dtype, error_name=None):
        """Generate the argument combinations for SLICE tests.

        For each requested permutation, picks a random non-empty start/size
        pair per dimension (fixed to start=0, size=1 for dimensions of
        size 1).  When an ERROR_IF test is requested, the start/size values
        are replaced with deliberately invalid ones.

        Returns:
            list of tuples: (arg_str, args_dict) where args_dict holds
            "start" and "size"
        """
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            start = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    # Random start within the dimension, random size that
                    # fits in the remainder
                    start.append(rng.randInt(0, ifm_shape[i]))
                    size.append(rng.randInt(0, ifm_shape[i] - start[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    start.append(0)
                    size.append(1)

            if valid:
                # If ERROR_IF test required then incorrect start, size will be returned
                start, size = TosaErrorIfArgGen.eiSliceErrorIf(
                    rng, error_name, ifm_shape, start, size
                )
                arg_list.append(("perm{}".format(p), {"start": start, "size": size}))
        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            shapeList,
            dtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
3360
    @staticmethod
    def agTile(testGen, rng, opName, shapeList, dtype, error_name=None):
        """Generate the argument combinations for TILE tests.

        For each requested permutation, picks small random multiples per
        dimension, clamping them down for large input shapes to keep the
        output tensor size manageable.

        Returns:
            list of tuples: (arg_str, args_dict) where args_dict holds
            "multiples"
        """
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random, but small multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                if ifm_shape[i] > 1000:
                    # Multiple of 1 if ifm_shape dimension is large to reduce
                    # tensor size
                    multiples.append(1)
                elif max(ifm_shape) > 1000:
                    # Some other dimension is large - keep multiples small
                    multiples.append(2)
                else:
                    multiples.append(rng.randInt(1, 4))
            arg_list.append(("perm{}".format(p), {"multiples": multiples}))

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            shapeList,
            dtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
3396
3397 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003398 def agResize(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003399 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003400 ifm_shape = shapeList[0]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003401
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003402 def get_aspect_ratio_resize_params():
3403 common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003404 aspect_ratio = rng.choice(common_aspect_ratios)
3405 invert = rng.choice((False, True))
3406 letterbox = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003407
3408 scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
3409 scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
3410 scale_y_d = scale_x_d = 1
3411 offset_x = offset_y = 0
3412
3413 if letterbox:
3414 max_border = scale_y_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003415 border_y = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003416 border_x = 0
3417 else:
3418 # Pillarboxing
3419 border_y = 0
3420 max_border = scale_x_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003421 border_x = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003422
3423 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3424 offset = (offset_y, offset_x)
3425 border = (border_y, border_x)
3426
3427 return scale, offset, border
3428
3429 def get_upscale_downscale_params():
3430 valid_params = False
3431 while not valid_params:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003432 upscale = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003433
3434 # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003435 origin_sampling = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003436
3437 if upscale:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003438 shift = rng.randInt(low=1, high=4)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003439 scale_x_d = scale_y_d = 1
3440 scale_x_n = scale_y_n = (
3441 1 << shift if origin_sampling else 2 << shift
3442 )
3443 border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
3444 offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
3445 else:
3446 scale_x_n = 1
3447 scale_y_n = 1
3448
3449 # Return list of valid scale_*_d values (max value 4) given input dim shape
3450 def get_valid_denom(ifm_dim):
3451 return [x for x in range(1, 5) if ifm_dim % x == 1]
3452
3453 # Generate list of valid downscale values and choose one randomly
3454 valid_scale_y_ds = get_valid_denom(ifm_shape[1])
3455 valid_scale_x_ds = get_valid_denom(ifm_shape[2])
3456
3457 if not valid_scale_y_ds and not valid_scale_x_ds:
3458 # Bad parameters, skip
3459 continue
3460
3461 if not valid_scale_y_ds:
3462 scale_y_d = 1
3463 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003464 scale_y_d = rng.choice(valid_scale_y_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003465
3466 if not valid_scale_x_ds:
3467 scale_x_d = 1
3468 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003469 scale_x_d = rng.choice(valid_scale_x_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003470
3471 border_x = border_y = 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003472 offset_y = rng.randInt(0, 16 * scale_y_n)
3473 offset_x = rng.randInt(0, 16 * scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003474 valid_params = True
3475
3476 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3477 offset = (offset_y, offset_x)
3478 border = (border_y, border_x)
3479 return scale, offset, border
3480
3481 def get_rand_params():
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003482 def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
3483 scale = scale_n / scale_d
3484 if scale > max_scale:
3485 factor = scale / max_scale
3486 new_scale_d = math.ceil(scale_d * factor)
3487 assert scale_n / new_scale_d <= max_scale
3488 scale_d = new_scale_d
3489 return scale_d
3490
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003491 # Scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003492 scale_y_n = rng.randInt(low=1, high=(1 << 11))
3493 scale_x_n = rng.randInt(low=1, high=(1 << 11))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003494
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003495 scale_y_d = rng.randInt(low=1, high=(16 * scale_y_n))
3496 scale_x_d = rng.randInt(low=1, high=(16 * scale_x_n))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003497
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003498 scale_y_d = fix_scale_to_max_scale(
3499 scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3500 )
3501 scale_x_d = fix_scale_to_max_scale(
3502 scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3503 )
3504
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003505 # Offsets and border within the scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003506 offset_y = rng.randInt(low=-scale_y_n, high=(16 * scale_y_n))
3507 offset_x = rng.randInt(low=-scale_x_n, high=(16 * scale_x_n))
3508 border_y = rng.randInt(low=(-16 * scale_y_n), high=scale_y_n)
3509 border_x = rng.randInt(low=(-16 * scale_x_n), high=scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003510
3511 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3512 offset = (offset_y, offset_x)
3513 border = (border_y, border_x)
3514 return scale, offset, border
3515
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003516 def get_level_8k_params():
3517 # Create 64x scale - 64/1 to 2048/32
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003518 scale_d = rng.randInt(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003519 low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
3520 )
3521 scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
3522 # Create half to fifth scaling
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003523 scale_d_alt = rng.randInt(low=2, high=6)
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003524 scale_n_alt = 1
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003525 switch = rng.choice((False, True))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003526 if switch:
3527 scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
3528 else:
3529 scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)
3530
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003531 offset_y = rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
3532 offset_x = rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003533 offset = (offset_y, offset_x)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003534 border_y = rng.choice((-16 * scale[0], 0, scale[0] - 1))
3535 border_x = rng.choice((-16 * scale[2], 0, scale[2] - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003536 border = (border_y, border_x)
3537 return scale, offset, border
3538
        # Generate the test permutations for each resize mode.
        # NOTE(review): this loop uses names defined earlier in agResize
        # (ifm_shape, arg_list, the get_*_params helpers) - presumably
        # ifm_shape is NHWC so [1]/[2] are H/W; verify against the caller.
        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.BF16:
                outputDTypeList = [DType.BF16]
            elif dtype == DType.FP32:
                outputDTypeList = [DType.FP32]
            elif dtype == DType.FP8E4M3:
                outputDTypeList = [DType.FP8E4M3]
            elif dtype == DType.FP8E5M2:
                outputDTypeList = [DType.FP8E5M2]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            # Test name: mode, output type, scale y_n x y_d x x_n x x_d,
            # offset y x x, border y x x
            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                # perm counts permutations accepted (or deliberately
                # counted as skipped) so the while loop terminates
                perm = 0
                while perm < testGen.args.num_rand_permutations:
                    # Random choice of type of params we are testing
                    if not testGen.args.level8k:
                        _rnd_param_fn = rng.choice(
                            (
                                get_rand_params,
                                get_upscale_downscale_params,
                                get_aspect_ratio_resize_params,
                            )
                        )
                        scale, offset, border = _rnd_param_fn()
                    else:
                        scale, offset, border = get_level_8k_params()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        # Look for non-integer test
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            if perm > 0:
                                # Already have one test - count the skip so
                                # the loop still terminates
                                perm += 1
                            continue
                    else:
                        # Alter the scaling factors to make the output integer
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1
                        # Make sure we are still within max scaling
                        if (
                            scale_y_n / scale_y_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
                            scale_x_n / scale_x_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
                            # Skip the test as it is using too large a scaling factor
                            if perm > 0:
                                perm += 1
                            continue

                    # Integer output dims now that scale_*_d divide exactly
                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1

                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        if not testGen.args.level8k or perm > 0:
                            perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= gtu.MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= gtu.MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    # Re-pack possibly-adjusted params as lists
                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]

                    # Common for all data types
                    if error_name is not None:
                        # Mutate params/output type to trigger the ERROR_IF case
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            rng,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    # Build the (test-name-suffix, args-dict) tuple
                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        {
                            "mode": mode,
                            "scale": scale,
                            "offset": offset,
                            "border": border,
                            "output_dtype": outputDTypeNew,
                        },
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            shapeList,
            dtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
3730
3731 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003732 def agTable(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003733 arg_list = []
3734
3735 if dtype == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003736 table = np.int32(rng.integers(low=-128, high=128, size=[256])).tolist()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003737 else: # INT16
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003738 table = np.int32(rng.integers(low=-32768, high=32768, size=[513])).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003739 # Make sure all slopes are within REQUIRE min/max 16-bit int
3740 for idx in range(len(table) - 1):
3741 slope = table[idx + 1] - table[idx]
3742 # Alter the next table entry to force the slope to be ok
3743 if slope > 32767:
3744 table[idx + 1] -= slope - 32767
3745 if slope < -32768:
3746 table[idx + 1] -= slope + 32768
3747 slope = table[idx + 1] - table[idx]
3748 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003749 arg_list.append(
3750 (
3751 "",
Jeremy Johnson587cc842024-02-08 11:45:44 +00003752 {"table": table},
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003753 )
3754 )
Jeremy Johnson587cc842024-02-08 11:45:44 +00003755 # Now add data generator types
3756 arg_list = TosaArgGen._add_data_generators(
3757 testGen,
3758 opName,
evacha019c96eef2024-02-07 11:21:55 +00003759 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003760 dtype,
3761 arg_list,
3762 error_name,
3763 )
3764 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003765 return arg_list
3766
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003767 def agCondIf(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003768 # CondIf generates the condition values here.
3769 # Convert to tensors in the build function, along with the
3770 # then and else blocks
3771 arg_list = []
3772
3773 for c in [False, True]:
Jeremy Johnson587cc842024-02-08 11:45:44 +00003774 arg_list.append(("cond{}".format(int(c)), {"condition": c}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003775
Jeremy Johnson587cc842024-02-08 11:45:44 +00003776 # Now add data generator types
3777 arg_list = TosaArgGen._add_data_generators(
3778 testGen,
3779 opName,
evacha019c96eef2024-02-07 11:21:55 +00003780 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003781 dtype,
3782 arg_list,
3783 error_name,
3784 )
3785 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003786 return arg_list
3787
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003788 def agWhileLoop(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003789 # While loop: 0 iterations, 1, more than 1
3790 arg_list = []
3791
Jeremy Johnson587cc842024-02-08 11:45:44 +00003792 for iterations in [0, 1, 4]:
3793 arg_list.append(("iter{}".format(iterations), {"iterations": iterations}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003794
Jeremy Johnson587cc842024-02-08 11:45:44 +00003795 # Now add data generator types
3796 arg_list = TosaArgGen._add_data_generators(
3797 testGen,
3798 opName,
evacha019c96eef2024-02-07 11:21:55 +00003799 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003800 dtype,
3801 arg_list,
3802 error_name,
3803 )
3804 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003805 return arg_list