# Copyright (c) 2021-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
import logging
import math

import generator.tosa_utils as gtu
import numpy as np
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode

# DTypeNames, DType, Op and ResizeMode are convenience variables to the
# flatc-generated types that should be enums, but aren't

logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")


class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(rng, zeropoint, dtype, error_name=None):

        if dtype == DType.INT8:
            if zeropoint is not None:
                return min(127, max(-128, zeropoint))
            return rng.randInt(-128, 128)
        elif dtype == DType.UINT8:
            if zeropoint is not None:
                return min(255, max(0, zeropoint))
            return rng.randInt(0, 256)
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            zero_point = rng.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

    @staticmethod
    def qgUnary(rng, zeropoint, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
            ]
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
            ]
        return qinfo

    @staticmethod
    def qgConv(rng, zeropoint, op, dtype_or_dtypeList, error_name=None):
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0], error_name),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1]),
            ]
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0]),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1], error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0]),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1]),
            ]
        return qinfo

    @staticmethod
    def qgMatmul(rng, zeropoint, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
            ]
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift

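        # Added illustrative note (not from the original source), a minimal worked
        # example assuming scale32 is True (scaleBits = 31):
        #   math.frexp(0.5) -> (0.5, 0), so multiplier = round(0.5 * (1 << 31)) = 1 << 30
        #   and shift = (-0) + 31 = 31, giving 0.5 ~= multiplier * 2**-shift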
        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        logger.debug(
            f"computeMultiplierAndShift: scalefp={scaleFp} scaleBits={scaleBits} m={m} mult={multiplier} shift={shift}"
        )

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift


class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]
        shape = testGen.makeShape(rng, rank)
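        # Added illustrative note (assumed values): for an op declaring operands
        # (pl=2, const=0) and a generated shape of [3, 5], the shape_list built
        # below would be [[3, 5], [3, 5]]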

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rng, rank + rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rng, rank + rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rng, rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgGather(testGen, rng, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_shape = testGen.makeShape(rng, rank)
        values_shape = testGen.constrictBatchSize(values_shape)

        N = values_shape[0]
        W = testGen.makeDimension(rng)
        indices_shape = [N, W]

        shape_list = [values_shape, indices_shape]
        return shape_list

    @staticmethod
    def tgScatter(testGen, rng, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rng, rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
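        # Added illustrative note: e.g. with K=4, W is capped at 4 so each of the
        # W scatter writes can target a distinct index in [0, K)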
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = rng.randInt(W_min, W_max) if W_min < W_max else W_min

        input_shape = [N, W, C]

        shape_list = []
        shape_list.append(values_in_shape)
        shape_list.append([N, W])  # indices
        shape_list.append(input_shape)

        return shape_list

    @staticmethod
    def _get_broadcast_shapes(testGen, rng, num_shapes, rank, error_name=None):
        shape = testGen.makeShape(rng, rank)
        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = rng.randInt(0 if error_name is None else 1, num_shapes)
        fuzz_idx = rng.randInt(0, rank)
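        # Added illustrative note (assumed values): with shape=[2, 3, 4], num_shapes=2,
        # bcast_idx=1 and fuzz_idx=2, the loop below yields [[2, 3, 4], [2, 3, 1]];
        # the chosen input gets a size 1 dimension so it broadcasts against the other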

        for i in range(num_shapes):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(rng, extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]
        num_shapes = pl + const
        return TosaTensorGen._get_broadcast_shapes(
            testGen, rng, num_shapes, rank, error_name
        )

    @staticmethod
    def tgMul(testGen, rng, op, rank, error_name=None):
        # Get broadcast shapes for the first 2 inputs as the 3rd is shift
        shape_list = TosaTensorGen._get_broadcast_shapes(
            testGen, rng, 2, rank, error_name
        )
        # Add a single dimension tensor for shift
        shape_list.append([1])
        return shape_list

    @staticmethod
    def tgConv2D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension(rng)

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension(rng)

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension(rng)

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension(rng) % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rng, rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))
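        # Added illustrative note: e.g. a generated height of 10 becomes
        # 2 ** int(math.log(10, 2)) == 8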

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rng, rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]
        shape = testGen.makeShape(rng, rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = rng.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(rng, shapeList, axis, error_name=None):
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList


class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each tensor in each test."""

    def __init__(self):
        pass

    class TVGInfo:
        """Enhanced tensor values information including data gen dict."""

        def __init__(self, tensorList, dataGenDict):
            self.tensorList = tensorList
            self.dataGenDict = dataGenDict

    # Default high value for random numbers
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
        DType.FP8E4M3: 448,
        DType.FP8E5M2: 57344,
    }
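    # Added note (assuming IEEE-style formats): these are the largest finite values
    # for each type, e.g. FP16 is (1 << 16) - (1 << 5) = 65504 and FP32 is
    # (2 - 2**-23) * 2**127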

    # Default lowest normal values for random numbers
    TVG_FLOAT_LOW_VALUE = {
        DType.FP32: np.exp2(-126),
        DType.FP16: np.exp2(-14),
        DType.BF16: np.exp2(-126),
        DType.FP8E4M3: np.exp2(-9),
        DType.FP8E5M2: np.exp2(-16),
    }

    @staticmethod
    def _get_data_range(rng, dtype, highValueLookup, lowValueLookup=None):
        # Return a tuple of (low,high) data range values for the given data
        # type using a combination of per operator table limits, data limits
        # and user supplied ranges for FP numbers
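        # Added illustrative note (assumed values): for DType.FP16 with
        # TVG_FLOAT_HIGH_VALUE_ADDSUB as the lookup, high_val is 65504 / 2 = 32752
        # and low_val is -32752, which are then clamped to the dtype/user range below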
        if dtype in highValueLookup:
            type_range = rng.dTypeRange(dtype, high_inclusive=True)
            high_val = highValueLookup[dtype]
            if lowValueLookup is not None and dtype in lowValueLookup:
                low_val = lowValueLookup[dtype]
            else:
                low_val = -high_val
            # Set the values to something that won't produce infinity whilst
            # respecting the default ranges if more/less than the low/high
            # values
            data_range = (
                max(low_val, type_range[0]),
                min(high_val, type_range[1]),
            )
            if data_range[0] > data_range[1]:
                # Invalid data range from low to high created due to user
                # constraints; revert to using internal ranges as they are
                # known to work
                logger.info(
                    f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
                )
                data_range = (low_val, high_val)
            return data_range
        return None

    @staticmethod
    def tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
        if "p_count" in argsDict:
            # Override for operators like CONCAT
            pCount = argsDict["p_count"]
            cCount = argsDict["c_count"]
        assert pCount + cCount == len(
            shapeList
        ), "Placeholders & Constant tensors must match shapes list"

        tens_ser_list = []

        if (
            error_name is not None
            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
            or "data_gen" not in testGen.TOSA_OP_LIST[opName]
        ):
            # Fall back to internal data gen when dealing with unsupported types or ops
            data_range = argsDict["data_range"] if "data_range" in argsDict else None
            for idx, info in enumerate(zip(shapeList, dtypeList)):
                roundMode = False
                shape, dtype = info
                if "data_range_list" in argsDict:
                    data_range = argsDict["data_range_list"][idx]["range"]
                    roundMode = (
                        "round" in argsDict["data_range_list"][idx]
                        and argsDict["data_range_list"][idx]["round"] is True
                    )
                if data_range is not None and dtype not in (
                    DType.FP16,
                    DType.FP32,
                    DType.BF16,
                    DType.FP8E4M3,
                    DType.FP8E5M2,
                ):
                    # Change from inclusive to exclusive range
                    data_range = (data_range[0], data_range[1] + 1)

                # Ignore lazy data gen option and create data array using any range limits
                if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                    if dtype == DType.SHAPE:
                        arr = np.int64(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT8:
                        arr = np.int8(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT16:
                        arr = np.int16(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT32:
                        arr = np.int32(argsDict["fixed_data"][idx])
                    else:
                        assert False, "Unsupported fixed_data type"
                else:
                    arr = rng.randTensor(shape, dtype, data_range)
                if roundMode:
                    arr = np.round(arr)
                if idx < pCount:
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
        dg_type = argsDict["dg_type"]
        tens_data = {
            "version": "0.1",
            "tensors": {},
        }
        dg_tens_meta = tens_data["tensors"]
        for idx, shape in enumerate(shapeList):

            tens_meta = {}
            if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                tens_meta["generator"] = gtu.DataGenType(
                    gtu.DataGenType.FIXED_DATA
                ).name
            else:
                tens_meta["generator"] = gtu.DataGenType(dg_type).name

            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()

            if idx < pCount:
                tens_meta["input_type"] = "VARIABLE"
            else:
                tens_meta["input_type"] = "CONSTANT"

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                if (
                    tens_meta["generator"]
                    == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name
                ):
                    info["data"] = [int(i) for i in argsDict["fixed_data"][idx]]
                    tens_meta["fixed_data_info"] = info
                else:
                    info["rng_seed"] = rng.seed

                    data_range = None
                    if "data_range_list" in argsDict:
                        data_range = argsDict["data_range_list"][idx]["range"]
                        if "round" in argsDict["data_range_list"][idx]:
                            info["round"] = argsDict["data_range_list"][idx]["round"]
                    elif "data_range" in argsDict:
                        data_range = argsDict["data_range"]

                    if data_range is None:
                        data_range = rng.dTypeRange(dtypeList[idx], high_inclusive=True)
                    info["range"] = [str(v) for v in data_range]
                    tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = int(argsDict["ks"])
                if "acc_type" in argsDict:
                    # Convert type number into JSON name
                    info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
                        "json"
                    ]
                if "kernel" in argsDict:
                    info["kernel"] = [int(k) for k in argsDict["kernel"]]
                if "axis" in argsDict:
                    info["axis"] = int(argsDict["axis"])
                tens_meta["dot_product_info"] = info
            elif dg_type == gtu.DataGenType.FULL_RANGE:
                info = {}
                info["start_val"] = int(
                    rng.randInt(0, gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["fullset"])
                )
                tens_meta["full_range_info"] = info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"

            # Using the finished generate config meta data - generate the data if
            # needed and assign a tensor name from the serializer

            # Need to generate data when not lazy or for the bias tensor as we need
            # to work out if the bias data is non-zero for compliance
            if not testGen.args.lazy_data_gen or (
                idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
            ):
                # Give this tensor a temporary name until we get one from the serializer
                temp_name = f"placeholder_{idx}"
                dg_tens_meta[temp_name] = tens_meta
                # Create data now using the temporary name to access meta details
                data = testGen.dgl.get_tensor_data(temp_name, tens_data)
                if tens_meta["data_type"] == "SHAPE":
                    # Tensor type SHAPE and Numpy file type must be the same
                    data = np.int64(data)
                # Remove the item as we will give it the correct name later
                del dg_tens_meta[temp_name]

            if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
                # The KS value used by compliance verification is altered when the
                # bias data is non-zero
                if max(abs(data)) > 0.0:
                    argsDict["ksb"] = argsDict["ks"] + 1

            if testGen.args.lazy_data_gen:
                data = None

            if tens_meta["input_type"] == "VARIABLE":
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
            else:
                tens = testGen.ser.addConst(shape, dtypeList[idx], data)

            tens_ser_list.append(tens)
            # Add the meta data to the list using the serializer tensor name
            dg_tens_meta[tens.name] = tens_meta

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)

    @staticmethod
    def tvgNegate(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Integer test
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.NEGATE must have 1 placeholders, 0 consts"
            # Must create tensors with values within accumulator (int32) negatable
            # range
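            # Added note: negate(-(1 << 31)) is not representable in int32, so the
            # minimum generated value is limited to -(2**31 - 1)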
            max_val = (1 << 31) - 1
            min_val = -max_val
            arr = np.int32(
                rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
            )
            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )

    # Set the ADD/SUB data range to half the largest value to avoid infinities
    TVG_FLOAT_HIGH_VALUE_ADDSUB = {
        DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
        DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
        DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
    }
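    # Added illustrative note: e.g. adding two FP16 values close to 65504 would
    # overflow to infinity, so FP16 inputs are capped at 65504 / 2 = 32752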

    @staticmethod
    def tvgAddSub(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
            # Make sure the integer operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            tens_ser_list = []
            add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
            data_range = None  # Use default
            if op["op"] in (Op.ADD_SHAPE, Op.SUB_SHAPE):
                data_range = testGen.args.tensor_shape_range
            a_arr = rng.randTensor(shapeList[0], dtypeList[0], data_range)
            b_arr = rng.randTensor(shapeList[1], dtypeList[1], data_range)
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)
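            # Added illustrative worked example (assumed values):
            #   a = 2000000000, b = 1500000000 -> res = 3500000000 exceeds max_i32
            #   by 1352516353; subtracting that from b gives b_unsat = 147483647
            #   and a + b_unsat == max_i32, so the ADD no longer saturates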

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            data_range = TosaTensorValuesGen._get_data_range(
                rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
            )
            if data_range:
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )

    @staticmethod
    def tvgCondIfWhileLoop(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        if dtypeList[0] in (
            DType.INT32,
            DType.INT16,
            DType.INT8,
        ):
            # Limit input tensors with cond_if_binary or while_loop to stop
            # saturation of add/sub ops with int32 and keep all logical shift
            # values between 0 to 31 for int16 or int8
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            pRemain = pCount
            tens_ser_list = []
            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[0] == DType.INT32:
                    arr = rng.randTensor(shapeList[idx], DType.INT16)
                else:
                    arr = np.int32(rng.integers(low=0, high=32, size=shapeList[idx]))
                if pRemain > 0:
                    tens_ser_list.append(
                        testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
                    )
                    pRemain -= 1
                else:
                    tens_ser_list.append(
                        testGen.ser.addConst(shape, dtypeList[idx], arr)
                    )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )

    @staticmethod
    def tvgArithmeticRightShift(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        op = testGen.TOSA_OP_LIST[opName]
        pCount, cCount = op["operands"]
        # Force value of operand[1] to be within [0, num_bits]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

        tens_ser_list = []
        for idx, shape in enumerate(shapeList[:]):
            if idx == 1:
                if dtypeList[idx] == DType.INT8:
                    arr = np.int32(rng.integers(low=0, high=8, size=shape))
                elif dtypeList[idx] == DType.INT16:
                    arr = np.int32(rng.integers(low=0, high=16, size=shape))
                elif dtypeList[idx] == DType.INT32:
                    arr = np.int32(rng.integers(low=0, high=32, size=shape))
                elif error_name == ErrorIf.WrongInputType:
                    arr = np.int32(rng.integers(low=0, high=8, size=shape))
                else:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
            else:
                arr = rng.randTensor(shape, dtypeList[idx])
            tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

    @staticmethod
    def tvgReshape(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        dtypeList[1] = DType.SHAPE
        shapeList[1] = [len(argsDict["new_shape"])]
        # Create a new list for the pre-generated data in argsDict["fixed_data"]
        argsDict["fixed_data"] = [None, argsDict["new_shape"]]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgRescale(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        scale32 = argsDict["scale"]
        multiplier_arr = argsDict["multiplier"]
        shift_arr = argsDict["shift"]

        if scale32:
            dtypeList[1] = DType.INT32
        else:
            dtypeList[1] = DType.INT16
        shapeList[1] = [len(multiplier_arr)]
        dtypeList[2] = DType.INT8
        shapeList[2] = [len(shift_arr)]
        # Create a new list for the pre-generated data in argsDict["fixed_data"]
        argsDict["fixed_data"] = [None, multiplier_arr, shift_arr]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgPad(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
        # argsDict["pad"] is 2D array, need to flatten it to get list of values
        pad_values = argsDict["pad"].flatten()
        dtypeList[1] = DType.SHAPE
        shapeList[1] = [len(pad_values)]
        # Create a new list for the pre-generated data in argsDict["fixed_data"]
        argsDict["fixed_data"] = [None, pad_values]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgSlice(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
        dtypeList[1] = DType.SHAPE
        shapeList[1] = [len(argsDict["start"])]
        dtypeList[2] = DType.SHAPE
        shapeList[2] = [len(argsDict["size"])]
        # Create a new list for the pre-generated data in argsDict["fixed_data"]
        argsDict["fixed_data"] = [None, argsDict["start"], argsDict["size"]]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgTile(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
        dtypeList[1] = DType.SHAPE
        shapeList[1] = [len(argsDict["multiples"])]
        argsDict["fixed_data"] = [None, argsDict["multiples"]]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgSelect(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        # Set datatype of condition tensor to boolean
        dtypeList[0] = DType.BOOL

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgIntDiv(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        if error_name is None:
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            tens_ser_list = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
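            #    (Added note: case 2 would overflow, since -(1 << 31) // -1 == 1 << 31,
            #    which is not representable in int32)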
1167 while True:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001168 dividend_arr = rng.randTensor(shapeList[0], dtypeList[0])
1169 divisor_arr = rng.randTensor(shapeList[1], dtypeList[1])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001170
1171 if (divisor_arr == 0).any():
1172 continue
1173
1174 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
1175 continue
1176
1177 break
1178
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001179 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001180 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1181 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001182 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001183 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1184 )
1185
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001186 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001187 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001188 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001189 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001190 )
1191
Jeremy Johnson30476252023-11-20 16:15:30 +00001192 # Set the MUL data range to the square root of the largest value
1193 # to avoid infinities
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001194 TVG_FLOAT_HIGH_VALUE_MUL = {
1195 DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1196 DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1197 DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1198 }
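    # For example, assuming TVG_FLOAT_HIGH_VALUE[FP32] is the largest finite
    # FP32 value (~3.4e38), each MUL operand is limited to roughly 1.8e19 so
    # that the product of any two in-range values stays finite.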
1199
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001200 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001201 def tvgMul(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001202 if error_name is not None or dtypeList[0] in (
1203 DType.FP16,
1204 DType.BF16,
1205 DType.FP32,
1206 ):
1207 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001208 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001209 rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001210 )
1211 if data_range:
1212 argsDict["data_range"] = data_range
1213
Jeremy Johnson0a042992024-02-28 13:20:05 +00001214 if dtypeList[0] != DType.SHAPE:
1215 # Need to supply shift tensor for MUL (not needed for MUL_SHAPE)
1216 dtypeList[2] = DType.INT8
1217 shapeList[2] = [1]
1218 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1219 argsDict["fixed_data"] = [None, None, [argsDict["shift"]]]
1220
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001221 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001222 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001223 )
1224 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001225 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001226 pCount, cCount = op["operands"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001227
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001228 tens_ser_list = []
1229
1230            # Make sure the multiply result fits in the int32 range
Won Jeon74342e52024-01-09 00:34:40 +00001231 if dtypeList[0] == DType.SHAPE:
1232 shift = 0
1233 else:
1234 shift = argsDict["shift"]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001235 if dtypeList[0] == DType.INT8:
1236 num_bits = 8
1237 elif dtypeList[0] == DType.INT16:
1238 num_bits = 16
Won Jeon74342e52024-01-09 00:34:40 +00001239 elif dtypeList[0] in (DType.INT32, DType.SHAPE):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001240 num_bits = 32
1241 elif error_name == ErrorIf.WrongInputType:
1242 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001243 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001244 raise Exception(
1245 f"OpMul: invalid input dtype {gtu.DTYPE_ATTRIBUTES[dtypeList[0]]['str']}"
1246 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001247
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001248 for idx, shape in enumerate(shapeList[:]):
Won Jeon74342e52024-01-09 00:34:40 +00001249 if dtypeList[idx] == DType.SHAPE:
1250 low = testGen.args.tensor_shape_range[0]
1251 high = testGen.args.tensor_shape_range[1]
1252 else:
1253 low = -(2 ** (num_bits - 1))
1254 high = (2 ** (num_bits - 1)) - 1
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001255
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001256 a_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[0]))
1257 b_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[1]))
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001258
1259 i = 0
1260 while True:
1261
1262 a_arr_64 = a_arr.astype(np.int64)
1263 b_arr_64 = b_arr.astype(np.int64)
1264
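                # Scale the 64-bit product back down with rounding: add half of
                # the divisor first, e.g. shift=4 adds 1 << 3 = 8 before >> 4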
1265 if shift > 0:
1266 rounding = 1 << (shift - 1)
1267 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001268 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001269 result_arr = a_arr_64 * b_arr_64
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001270
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001271 if (result_arr > -(2**31)).all() and (
1272 result_arr <= ((2**31) - 1)
1273 ).all():
1274 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001275
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001276 i = i + 1
1277 a_arr = a_arr // 2
1278 b_arr = b_arr // 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001279
Won Jeon74342e52024-01-09 00:34:40 +00001280 if dtypeList[0] == DType.SHAPE:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001281 # MUL_SHAPE with 2 inputs
Won Jeon74342e52024-01-09 00:34:40 +00001282 tens_ser_list.append(
1283 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
1284 )
1285 tens_ser_list.append(
1286 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
1287 )
1288 else:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001289 # MUL with 3 inputs (3rd is shift)
Won Jeon74342e52024-01-09 00:34:40 +00001290 tens_ser_list.append(
1291 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1292 )
1293 tens_ser_list.append(
1294 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1295 )
Jeremy Johnson0a042992024-02-28 13:20:05 +00001296 tens_ser_list.append(
1297 testGen.ser.addPlaceholder([1], DType.INT8, np.int8([shift]))
1298 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001299
1300 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001301
1302 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001303 def tvgConcat(
1304 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1305 ):
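        # Work out how many of the concat inputs are placeholders (p_count)
        # versus consts (c_count) from testGen.args.num_const_inputs_concat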
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001306 count = len(shapeList) - testGen.args.num_const_inputs_concat
1307 if count < 1:
1308 count = 1
1309 if testGen.args.num_const_inputs_concat == 0:
1310 count = len(shapeList)
1311
Won Jeon74342e52024-01-09 00:34:40 +00001312 op = testGen.TOSA_OP_LIST[opName]
1313 if op["op"] == Op.CONCAT_SHAPE:
1314 # Set the axis to 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001315 shapeList = TosaTensorGen.tgConcatConstInput(rng, shapeList, 0, error_name)
Won Jeon74342e52024-01-09 00:34:40 +00001316 else:
1317 shapeList = TosaTensorGen.tgConcatConstInput(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001318 rng, shapeList, argsDict["axis"], error_name
Won Jeon74342e52024-01-09 00:34:40 +00001319 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001320
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001321 # Override default pCount/cCount for operator
1322 argsDict["p_count"] = count
1323 argsDict["c_count"] = len(shapeList) - count
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001324
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001325 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001326 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001327 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001328
1329 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001330 def tvgLogicalShift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001331 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001332 ):
1333 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001334 pCount, cCount = op["operands"]
1335 assert (
1336 pCount == 2 and cCount == 0
1337 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001338 values_arr = rng.randTensor(shapeList[0], dtypeList[0])
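        # Shift amounts are kept within 0..31 (high=32 below is assumed to be
        # an exclusive upper bound, following numpy integers semantics)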
1339 shift_arr = np.int32(rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001340 tens_ser_list = []
1341 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001342 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1343 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001344 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001345 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1346 )
1347
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001348 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001349
1350 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001351 def tvgEqual(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnsona0150012023-11-15 15:52:06 +00001352 if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
1353 # Integer
1354 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001355 pCount, cCount = op["operands"]
1356 assert (
1357 pCount == 2 and cCount == 0
1358 ), "Op.EQUAL must have 2 placeholders, 0 consts"
Jeremy Johnsona0150012023-11-15 15:52:06 +00001359
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001360 a_arr = rng.randTensor(shapeList[0], dtypeList[0])
1361 b_arr = rng.randTensor(shapeList[1], dtypeList[1])
Jeremy Johnsona0150012023-11-15 15:52:06 +00001362
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001363            # Using random numbers means it is very unlikely that there will
1364            # be any matching (equal) values, so force twice as many matching
1365            # values as the tensor rank
1366 for num in range(0, len(shapeList[0]) * 2):
1367 a_index = []
1368 b_index = []
1369 # Choose an index in each axis for the whole shape
1370 for axis in range(0, len(shapeList[0])):
1371 # Index can be up to the largest dimension in both shapes
1372 index = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001373 rng.integers(0, max(shapeList[0][axis], shapeList[1][axis]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001374 )
1375 # Reduce the index down to a shape's dim for broadcasting
1376 a_index.append(min(shapeList[0][axis] - 1, index))
1377 b_index.append(min(shapeList[1][axis] - 1, index))
1378
1379 a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
1380
Jeremy Johnsona0150012023-11-15 15:52:06 +00001381 tens_ser_list = []
1382 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001383 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1384 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001385 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001386 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1387 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001388 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001389 else:
Jeremy Johnsona0150012023-11-15 15:52:06 +00001390 # ERROR_IF or floating point test
1391 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001392 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001393 )
1394
1395 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001396 def tvgReduceSum(
1397 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1398 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001399 dtype = dtypeList[0]
1400 if dtype == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001401 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001402 pCount, cCount = op["operands"]
1403 assert (
1404 pCount == 1 and cCount == 0
1405 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
1406 # Limit values so that the sum cannot exceed the range of an int32 during
1407 # summation of any axis
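            # e.g. a largest dimension of 64 limits values to +/- 2**31 // 64
            # = 33554432, so no reduction along an axis can overflow int32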
1408 range_val = int((1 << 31) / max(shapeList[0]))
1409 values_arr = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001410 rng.integers(low=-range_val, high=range_val, size=shapeList[0])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001411 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001412 tens_ser_list = []
1413 tens_ser_list.append(
Jeremy Johnson30476252023-11-20 16:15:30 +00001414 testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001415 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001416 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001417 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001418 # ERROR_IF or dot product floating point test
Jeremy Johnson30476252023-11-20 16:15:30 +00001419 if (
1420 error_name is None
1421 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
1422 ):
1423 # Limit ranges for (non error & non compliance) tests by using
1424 # values that can be summed on any axis to not hit infinity
1425 highval_lookup = {
1426 dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
1427 / max(shapeList[0])
1428 }
1429 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001430 rng, dtype, highval_lookup
Jeremy Johnson30476252023-11-20 16:15:30 +00001431 )
1432 assert data_range is not None
1433 argsDict["data_range"] = data_range
1434
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001435 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001436 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001437 )
1438
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001439 @staticmethod
1440 def tvgReduceProduct(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001441 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001442 ):
1443 dtype = dtypeList[0]
1444 if error_name is None:
1445 # Limit ranges for (non error) tests by using
1446 # values that can be multiplied on any axis to not hit infinity
1447 highval_lookup = {
1448 dtype: math.pow(
1449 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
1450 1 / max(shapeList[0]),
1451 )
1452 }
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001453 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001454 assert data_range is not None
1455 argsDict["data_range"] = data_range
1456
1457 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001458 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001459 )
1460
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001461 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001462 def tvgResize(
1463 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1464 ):
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001465 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001466 rng,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001467 dtypeList[0],
1468 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1469 )
1470 if data_range:
1471 argsDict["data_range"] = data_range
1472 # Needed for compliance
1473 argsDict["max_abs_value"] = data_range[1]
1474
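        # The scale, offset and border arguments are supplied as 1-D shape
        # tensors with pre-generated (fixed) data rather than random values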
Tai Lyc5c2a7e2024-02-22 23:26:28 +00001475 scale_values = argsDict["scale"]
1476 offset_values = argsDict["offset"]
1477 border_values = argsDict["border"]
1478 dtypeList[1] = DType.SHAPE
1479 dtypeList[2] = DType.SHAPE
1480 dtypeList[3] = DType.SHAPE
1481 shapeList[1] = [len(scale_values)]
1482 shapeList[2] = [len(offset_values)]
1483 shapeList[3] = [len(border_values)]
1484 argsDict["fixed_data"] = [None, scale_values, offset_values, border_values]
1485
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001486 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001487 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001488 )
1489
Jeremy Johnson30476252023-11-20 16:15:30 +00001490 # Set the POW exponent high data range
1491 TVG_FLOAT_HIGH_VALUE_POW_EXP = {
1492 DType.FP32: 10.0,
1493 DType.FP16: 10.0,
1494 DType.BF16: 10.0,
1495 }
1496    # POW highest base value (within a safe margin of error) that can be
1497    # raised to the +ve exponent above without becoming Infinity
1498 TVG_FLOAT_HIGH_VALUE_POW_BASE = {
1499 DType.FP32: math.floor(
1500 math.pow(
1501 TVG_FLOAT_HIGH_VALUE[DType.FP32],
1502 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1503 )
1504 ),
1505 DType.FP16: math.floor(
1506 math.pow(
1507 TVG_FLOAT_HIGH_VALUE[DType.FP16],
1508 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1509 )
1510 ),
1511 DType.BF16: math.floor(
1512 math.pow(
1513 TVG_FLOAT_HIGH_VALUE[DType.BF16],
1514 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1515 )
1516 ),
1517 }
1518    # POW lowest base value (within a safe margin of error) that can be
1519    # raised to the -ve exponent without becoming Infinity
1520 TVG_FLOAT_LOW_VALUE_POW_BASE = {
1521 DType.FP32: math.ceil(
1522 math.pow(
1523 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
1524 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1525 )
1526 * 1000
1527 )
1528 / 1000,
1529 DType.FP16: math.ceil(
1530 math.pow(
1531 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
1532 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1533 )
1534 * 1000
1535 )
1536 / 1000,
1537 DType.BF16: math.ceil(
1538 math.pow(
1539 1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
1540 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1541 )
1542 * 1000
1543 )
1544 / 1000,
1545 }
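    # Illustrative relationship: the base tables above are derived so that
    # base ** exponent stays within TVG_FLOAT_HIGH_VALUE for exponent
    # magnitudes up to TVG_FLOAT_HIGH_VALUE_POW_EXP (10.0).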
1546
1547 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001548 def tvgPow(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001549 if error_name is not None:
1550 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001551 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001552 )
1553 dtype = dtypeList[0]
1554 # Different ranges for POW
1555 test_set = argsDict["s"]
1556 if test_set == 0:
1557 # Positive base with fractional exponent
1558 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001559 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001560 dtype,
1561 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1562 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1563 )
1564 exp_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001565 rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
Jeremy Johnson30476252023-11-20 16:15:30 +00001566 )
1567 exp_round = False
1568 else:
1569 # Integer exponent
1570 exp_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001571 rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
Jeremy Johnson30476252023-11-20 16:15:30 +00001572 )
1573 exp_round = True
1574 if test_set == 1:
1575 # Positive base
1576 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001577 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001578 dtype,
1579 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1580 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1581 )
1582 else:
1583 assert test_set == 2
1584 # Negative base
1585 # Supply new look up tables with negative values
1586 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001587 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001588 dtype,
1589 {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
1590 {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
1591 )
1592
1593 data_range_list = (
1594 {
1595 "range": base_range,
1596 },
1597 {
1598 "range": exp_range,
1599 "round": exp_round,
1600 },
1601 )
1602 argsDict["data_range_list"] = data_range_list
1603 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001604 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001605 )
1606
1607 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001608 def tvgLogRsqrt(
1609 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1610 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001611        # LOG & RSQRT data range runs from the lowest expressible positive
1612        # number up to the largest value to avoid NaNs
1613 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001614 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001615 dtypeList[0],
1616 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1617 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1618 )
1619 if data_range:
1620 argsDict["data_range"] = data_range
1621
1622 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001623 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001624 )
1625
1626    # Set the EXP data range from the log of the smallest value up to the
1627    # log of the largest value to avoid infinities or a zero result
1628 TVG_FLOAT_HIGH_VALUE_EXP = {
1629 DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1630 DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1631 DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1632 }
1633 TVG_FLOAT_LOW_VALUE_EXP = {
1634 DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
1635 DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
1636 DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
1637 }
1638
1639 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001640 def tvgExp(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001641 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001642 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001643 dtypeList[0],
1644 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1645 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1646 )
1647 if data_range:
1648 argsDict["data_range"] = data_range
1649
1650 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001651 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001652 )
1653
1654 @staticmethod
1655 def tvgFullyConnected(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001656 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson30476252023-11-20 16:15:30 +00001657 ):
1658 dtype = dtypeList[0]
1659 if (
1660 error_name is None
1661 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
Jeremy Johnson718f3472023-11-30 14:18:19 +00001662 and dtype in (DType.BF16,)
Jeremy Johnson30476252023-11-20 16:15:30 +00001663 ):
Jeremy Johnson718f3472023-11-30 14:18:19 +00001664 # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
Jeremy Johnson30476252023-11-20 16:15:30 +00001665 # Limit ranges for (non error & non compliance) FP tests by using
1666 # values that can be multiplied on any axis to not hit infinity/NaN
1667 IC = shapeList[0][1]
1668 highval_lookup = {
1669 dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
1670 }
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001671 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnson30476252023-11-20 16:15:30 +00001672 assert data_range is not None
1673 argsDict["data_range"] = data_range
1674
1675 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001676 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001677 )
1678
Jeremy Johnson708da822023-11-15 16:25:45 +00001679 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001680 def tvgCast(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson708da822023-11-15 16:25:45 +00001681 in_dtype = dtypeList[0]
1682 out_dtype = argsDict["out_type"]
1683 # Create look up to limit input tensor to output type maximums to avoid
1684 # FP infinities and saturation of integers
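        # e.g. a cast to FP16 limits the input data to +/- 65504 (FP16 maximum)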
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001685 out_range = rng.dTypeRange(out_dtype, high_inclusive=True)
Jeremy Johnson708da822023-11-15 16:25:45 +00001686 highval_lookup = {in_dtype: out_range[1]}
1687 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001688 rng,
Jeremy Johnson708da822023-11-15 16:25:45 +00001689 in_dtype,
1690 highval_lookup,
1691 )
1692
1693 assert data_range is not None
1694 argsDict["data_range"] = data_range
1695
1696 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001697 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson708da822023-11-15 16:25:45 +00001698 )
1699
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001700 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001701 def tvgGather(
1702 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1703 ):
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001704 K = shapeList[0][1]
1705
1706 # Fix the type of the indices tensor
1707 dtypeList[1] = DType.INT32
1708
1709 dtype = dtypeList[0]
1710 if not gtu.dtypeIsSupportedByCompliance(dtype):
1711 # Test unsupported by data generator
1712 op = testGen.TOSA_OP_LIST[opName]
1713 pCount, cCount = op["operands"]
1714 assert (
1715 pCount == 2 and cCount == 0
1716 ), "Op.GATHER must have 2 placeholders, 0 consts"
1717
1718 tens_ser_list = []
1719 for idx, shape in enumerate(shapeList):
1720 dtype = dtypeList[idx]
1721 if idx != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001722 arr = rng.randTensor(shape, dtype)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001723 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1724 else:
1725                    # Limit the data range of the indices tensor to values below K (exclusive)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001726 arr = rng.randTensor(shape, dtype, (0, K))
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001727 # To match old functionality - create indices as CONST
1728 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1729
1730 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1731
1732 else:
1733 # ERROR_IF or floating point test
1734            # Use values up to index K-1 (inclusive) for the indices tensor
1735 data_range_list = (
1736 {"range": None},
1737 {"range": (0, K - 1)},
1738 )
1739 argsDict["data_range_list"] = data_range_list
1740
1741 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001742 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001743 )
1744
1745 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001746 def tvgScatter(
1747 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1748 ):
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001749 K = shapeList[0][1]
1750 W = shapeList[2][1]
1751
1752        # Work out an indices tensor here with data that doesn't exceed the
1753        # dimension K of the values_in tensor and that, as required by the
1754        # spec, does NOT repeat the same K location:
1755 # "It is not permitted to repeat the same output index within a single
1756 # SCATTER operation and so each output index occurs at most once."
1757 assert K >= W, "Op.SCATTER W must be smaller or equal to K"
1758
1759 # Fix the type of the indices tensor
1760 dtypeList[1] = DType.INT32
1761
1762 dtype = dtypeList[0]
1763 if not gtu.dtypeIsSupportedByCompliance(dtype):
1764 # Test unsupported by data generator
1765 op = testGen.TOSA_OP_LIST[opName]
1766 pCount, cCount = op["operands"]
1767 assert (
1768 pCount == 3 and cCount == 0
1769 ), "Op.SCATTER must have 3 placeholders, 0 consts"
1770
1771 tens_ser_list = []
1772 for idx, shape in enumerate(shapeList):
1773 dtype = dtypeList[idx]
1774 if idx != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001775 arr = rng.randTensor(shape, dtype)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001776 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1777 else:
1778 # Create the indices array
1779 assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
1780 arr = []
1781 for n in range(shape[0]):
1782 # Get a shuffled list of output indices (0 to K-1) and
1783 # limit length to W
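                        # e.g. K=5, W=3 might give [4, 0, 2], so no output
                        # index is repeated within a batch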
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001784 arr.append(rng.permutation(K)[:W])
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001785 indices_arr = np.array(arr, dtype=np.int32) # (N, W)
1786 # To match old functionality - create indices as CONST
1787 tens_ser_list.append(
1788 testGen.ser.addConst(shape, dtype, indices_arr)
1789 )
1790
1791 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1792
1793 else:
1794 # ERROR_IF or floating point test
1795            # Use values up to index K-1 (inclusive) for the indices tensor
1796 data_range_list = (
1797 {"range": None},
1798 {"range": (0, K - 1)},
1799 {"range": None},
1800 )
1801 argsDict["data_range_list"] = data_range_list
1802
1803 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001804 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001805 )
1806
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001807
1808class TosaArgGen:
1809 """Argument generators create exhaustive or random lists of attributes for
1810 operators that take attributes or other parameters.
1811
1812 The return value is a list of (descriptive_name, [arglist]) tuples where
1813 the descriptive_name is appended to the test name and the arglist is expanded
1814 as arguments to the operator build function.
1815 """
1816
1817 def __init__(self):
1818 pass
1819
1820 @staticmethod
evacha019c96eef2024-02-07 11:21:55 +00001821 def _add_data_generators(testGen, opName, shapeList, dtype, arg_list, error_name):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001822 """Add extra tests for each type of data generator for this op."""
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001823 if (
1824 error_name is None
1825 and "data_gen" in testGen.TOSA_OP_LIST[opName]
1826 and gtu.dtypeIsSupportedByCompliance(dtype)
1827 ):
Tai Ly60dc48c2024-03-08 22:19:41 +00001828 if gtu.dtypeIsFloat(dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001829 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
1830 else:
1831 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
1832 else:
1833            # Error test or no data generator types listed - assume pseudo random
1834 dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
1835
1836 # Expand arg list with other data generator types
1837 new_arg_list = []
1838 for dg_type in dataGenTypesList:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001839 for arg_str, args_dict in arg_list:
evacha019c96eef2024-02-07 11:21:55 +00001840
1841 if dg_type == gtu.DataGenType.FULL_RANGE:
1842 tensor_size = gtu.product(shapeList[0])
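                    # Check the tensor can hold every value of the dtype,
                    # e.g. covering all INT16 values needs 2**16 = 65536 elements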
1843 if tensor_size >= gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]:
1844 # Large enough tensor data size for full range, add a single test
1845 num_test_sets = 0
1846 else:
1847 # Not enough data size for full range of values, revert to random numbers
1848 dg_type = gtu.DataGenType.PSEUDO_RANDOM
1849
Jeremy Johnson1271c442023-09-05 11:39:26 +01001850 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
Jeremy Johnson30476252023-11-20 16:15:30 +00001851 if error_name is None:
1852 num_test_sets = (
1853 args_dict["num_test_sets"]
1854 if "num_test_sets" in args_dict
1855 else 0
1856 )
1857 else:
evacha019c96eef2024-02-07 11:21:55 +00001858 # Add single test for pseudo random
Jeremy Johnson30476252023-11-20 16:15:30 +00001859 num_test_sets = 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01001860
1861 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
1862 # Extra tests for each dot product test set
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001863 dot_products = args_dict["dot_products"]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001864 if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001865 shape_info = (
1866 " ({})".format(testGen.shapeStr(args_dict["shape"]))
1867 if "shape" in args_dict
1868 else ""
1869 )
Jeremy Johnsonaf090182024-02-13 18:25:39 +00001870 logger.info(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001871 f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
Jeremy Johnson1271c442023-09-05 11:39:26 +01001872 )
1873 continue
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001874 # KS and acc_type is required by all dot product generators
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001875 assert "ks" in args_dict
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001876 assert "acc_type" in args_dict
Jeremy Johnson1271c442023-09-05 11:39:26 +01001877
Jeremy Johnson30476252023-11-20 16:15:30 +00001878 num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS
1879
1880 if num_test_sets > 0:
1881 for s in range(0, num_test_sets):
evacha019c96eef2024-02-07 11:21:55 +00001882 set_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
1883 set_args_dict = args_dict.copy()
1884 set_args_dict["s"] = s
1885 set_args_dict["dg_type"] = dg_type
1886 new_arg_list.append((set_arg_str, set_args_dict))
Jeremy Johnson30476252023-11-20 16:15:30 +00001887 else:
1888 # Default is a single test
evacha019c96eef2024-02-07 11:21:55 +00001889 new_args_dict = args_dict.copy()
1890 new_args_dict["dg_type"] = dg_type
1891 new_arg_list.append((arg_str, new_args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001892
1893 return new_arg_list
1894
1895 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001896 def agNone(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001897 """A trivial argument generator for operators that don't take any
1898 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001899 arg_list = TosaArgGen._add_data_generators(
1900 testGen,
1901 opName,
evacha019c96eef2024-02-07 11:21:55 +00001902 shapeList,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001903 dtype,
1904 [("", {})],
1905 error_name,
1906 )
1907 # Return list of tuples: (arg_str, args_dict)
1908 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001909
1910 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001911 def agPow(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001912 """Pow operator needs different test sets to cover random numbers
1913 without creating NaNs or Infs"""
1914 arg_list = TosaArgGen._add_data_generators(
1915 testGen,
1916 opName,
evacha019c96eef2024-02-07 11:21:55 +00001917 shapeList,
Jeremy Johnson30476252023-11-20 16:15:30 +00001918 dtype,
1919 [("", {"num_test_sets": 3})],
1920 error_name,
1921 )
1922 # Return list of tuples: (arg_str, args_dict)
1923 return arg_list
1924
1925 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001926 def agAxis(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001927 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001928 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001929 shape = shapeList[0]
1930
1931 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001932 # Set too small axis
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001933 axes = [rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001934 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001935 # Set too large axis
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001936 axes = [rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001937 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001938 # Create tests for each dimension
1939 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001940
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001941 opid = testGen.TOSA_OP_LIST[opName]["op"]
1942
1943 for a in axes:
1944 args_dict = {"axis": int(a)}
1945 if opid == Op.REDUCE_SUM:
Jeremy Johnsone52c0a32024-03-11 09:58:24 +00001946 output_shape = shape.copy()
1947 if error_name is None:
1948 # It only matters that we calculate the dot_products correctly
1949 # for non error_if tests as they should never be run
1950 output_shape[a] = 1
1951 args_dict["dot_products"] = gtu.product(output_shape)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001952 args_dict["shape"] = shape
1953 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1954 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1955
1956 arg_list.append(("axis{}".format(a), args_dict))
1957
1958 arg_list = TosaArgGen._add_data_generators(
1959 testGen,
1960 opName,
evacha019c96eef2024-02-07 11:21:55 +00001961 shapeList,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001962 dtype,
1963 arg_list,
1964 error_name,
1965 )
1966 # Return list of tuples: (arg_str, args_dict)
1967 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001968
1969 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001970 def _calculate_sparsity(num_tests, sparsity_factor):
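        # Keep roughly one in every 'sparsity' parameter combinations,
        # e.g. 12000 candidates with a sparsity_factor of 120 gives a sparsity
        # of 101; small candidate counts are selected in full (sparsity of 1)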
1971 sparsity = num_tests // sparsity_factor + 1
1972 # If there are only a small number of tests, just select them all
1973 if sparsity < 13:
1974 sparsity = 1
1975 # To get a variety of parameter combinations sparsity should not be a
1976 # multiple of 2, 3 or 5
1977 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1978 sparsity += 1
1979 return sparsity
1980
Jeremy Johnsondd975b82024-02-28 17:29:13 +00001981 # Maximum number of error_if variants to produce
1982 MAX_CONV_ERROR_IFS = 3
1983
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001984 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001985 def agConv(testGen, rng, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson0c716862023-04-13 17:18:19 +01001986 # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001987 arg_list = []
1988
Jeremy Johnson0c716862023-04-13 17:18:19 +01001989 if testGen.args.level8k and error_name is not None:
1990 # Don't produce negative large tests
1991 return arg_list
1992
1993 # Shape: Batches, (Depth), Height, Width, Channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001994 ifm_shape = shapeList[0]
Jeremy Johnson0c716862023-04-13 17:18:19 +01001995 # Shape: (OFM channels), (KD), KH, KW, IFM channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001996 filter_shape = shapeList[1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001997
Tai Lyf36f2562024-03-14 16:21:29 +00001998 accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
1999
2000 if error_name == ErrorIf.WrongAccumulatorType:
2001 accum_dtypes = (
2002 [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
2003 )
James Ward8b390432022-08-12 20:48:56 +01002004
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002005 # Op type checks
Jeremy Johnson0c716862023-04-13 17:18:19 +01002006 conv3d = opName.startswith("conv3d")
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002007 depthwise = opName.startswith("depthwise")
2008
2009 # Check the rank
Jeremy Johnson0c716862023-04-13 17:18:19 +01002010 rank = 5 if conv3d else 4
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002011 if error_name != ErrorIf.WrongRank:
2012 assert len(ifm_shape) == rank
2013 assert len(filter_shape) == rank
2014
Jeremy Johnson0c716862023-04-13 17:18:19 +01002015 # kernel rank omits channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002016 k_rank = rank - 2
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002017 k_pos = 0 if depthwise else 1
Jeremy Johnson0c716862023-04-13 17:18:19 +01002018 k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002019 # compliance size - KS
2020 k_size = gtu.product(k_shape)
2021 if not depthwise:
2022 k_size *= ifm_shape[-1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002023
Jeremy Johnson0c716862023-04-13 17:18:19 +01002024 if not testGen.args.level8k:
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002025 if error_name in (
2026 ErrorIf.PadSmallerZero,
2027 ErrorIf.StrideSmallerOne,
2028 ErrorIf.DilationSmallerOne,
2029 ):
2030 # Use specific invalid value(s)
2031 if error_name == ErrorIf.PadSmallerZero:
2032 # Create negative paddings but with positive opposite paddings
2033 neg_pad = rng.choice(range(-5, 0))
2034 p_vals = [neg_pad, abs(neg_pad)]
2035 else:
2036 p_vals = [0, 0]
2037 if error_name == ErrorIf.StrideSmallerOne:
2038                    # Can't use stride=0, as it is used as a divisor when deriving the output shape
2039 s_vals = [rng.choice(range(-5, 0))]
2040 else:
2041 s_vals = [1]
2042 if error_name == ErrorIf.DilationSmallerOne:
2043 d_vals = [rng.choice(range(-5, 1))]
2044 else:
2045 d_vals = [1]
2046 p = p_vals * k_rank
2047 s = s_vals * k_rank
2048 d = d_vals * k_rank
2049
2050 # Fix values to produce valid error_if
2051 for index in range(k_rank):
2052 pad_offset = index * 2
2053 fixed = False
2054 while not fixed:
2055 partial = (
2056 ifm_shape[index + 1]
2057 - 1
2058 + p[pad_offset]
2059 + p[pad_offset + 1]
2060 - (k_shape[index] - 1) * d[index]
2061 )
2062 remainder = partial % s[index]
2063 if partial <= 0:
2064 p[pad_offset + 1] += abs(partial) + 1
2065 elif remainder:
2066 # Stride will be negative for StrideSmallerOne
2067 assert remainder < 0
2068 p[pad_offset + 1] += abs(remainder)
2069 else:
2070 fixed = True
2071 paddings = {tuple(p)}
2072 strides = {tuple(s)}
2073 dilations = {tuple(d)}
2074 logger.debug(f"agConv: error pad={p} stride={s} dilation={d}")
Jeremy Johnson0c716862023-04-13 17:18:19 +01002075 else:
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002076 # Generate comprehensive argument lists
Jeremy Johnson0c716862023-04-13 17:18:19 +01002077 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002078 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
Jeremy Johnson0c716862023-04-13 17:18:19 +01002079 # Stride must be greater than 1 to force non-integer error
2080 startStride = (
2081 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002082 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002083 s_vals = [
2084 x for x in range(startStride, testGen.args.max_conv_stride + 1)
2085 ]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002086 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002087
2088 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
2089 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002090
Jeremy Johnson0c716862023-04-13 17:18:19 +01002091 if not error_name and testGen.args.oversize:
2092 # add some oversize argument values
2093 if max(ifm_shape) < 64:
2094 bigPadding = 9
2095 paddings.update(
2096 {
2097 x
2098 for x in itertools.product(
2099 *([[0, bigPadding]] * (k_rank * 2))
2100 )
2101 }
2102 )
2103 bigStride = 8
2104 strides.update(
2105 {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
2106 )
2107 bigDilation = 7
2108 dilations.update(
2109 {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
2110 )
2111 max_dim_size = None
2112
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002113 if error_name:
2114 # Cycle through all error_if tests but we only keep the first few
2115 sparsity = 1
2116 else:
2117                # There are too many parameter combinations, so generate them sparsely
2118 sparsity_factor = 120
2119 sparsity = TosaArgGen._calculate_sparsity(
2120 len(paddings) * len(strides) * len(dilations), sparsity_factor
2121 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002122 else:
2123 # Only test 8k levels boundaries
2124 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2125 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2126 bigPadding = bigKernel
2127
2128 dilation_shape = [1] * k_rank
2129 pad_shape = [0] * k_rank * 2
2130 if conv3d:
2131                # Small stride, apart from the big kernel case (see below), to
2132                # keep the tensor size/calculation small
2133 stride_shape = [1] * k_rank
2134 for idx in range(k_rank):
2135 pad_offset = idx * 2
2136 if k_shape[idx] == bigKernel:
2137 # Padding shape needs to account for tensor shape
2138 pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
2139 pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
2140 # Big stride to reduce output size
2141 stride_shape[idx] = bigKernel
2142 else:
2143 # Account for kernel size
2144 pad_shape[pad_offset] = k_shape[idx] - 1
2145 else:
2146 # Always have a large stride with extra padding and dilation to keep
2147 # tensor calculation reasonable
2148 stride_shape = [bigKernel] * k_rank
2149 for idx in range(k_rank):
2150 # Dilation shape must account for kernel size
2151 dilation_shape[idx] = bigKernel // k_shape[idx]
2152 # Padding shape needs to accommodate tensor/kernel & dilation
2153 pad_offset = idx * 2
2154 pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
2155 pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
2156
2157 strides = {tuple(stride_shape)}
2158 dilations = {tuple(dilation_shape)}
2159 paddings = {tuple(pad_shape)}
2160 # Create a limit for the output dimensions size
2161 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2162
2163            # Currently allow all combinations that are a reasonable size
2164 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002165
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002166 more_tests = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002167 n = 0
Tai Lyf36f2562024-03-14 16:21:29 +00002168 for a in accum_dtypes:
2169 for s in sorted(list(strides)):
2170 for p in sorted(list(paddings)):
2171 for d in sorted(list(dilations)):
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002172 if (
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002173 more_tests
2174 and n % sparsity == 0
Tai Lyf36f2562024-03-14 16:21:29 +00002175 # the padded shape must exceed the dilation * kernel to get a positive
2176 # sized output shape
2177 and (ifm_shape[1] - 1 + p[0] + p[1])
2178 > d[0] * (k_shape[0] - 1)
2179 and (ifm_shape[2] - 1 + p[2] + p[3])
2180 > d[1] * (k_shape[1] - 1)
2181 and (
2182 k_rank < 3
2183 or (
2184 (ifm_shape[3] - 1 + p[4] + p[5])
2185 > d[2] * (k_shape[2] - 1)
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002186 )
2187 )
Tai Lyf36f2562024-03-14 16:21:29 +00002188 ):
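                            # Convolution output size per spatial axis:
                            #   out = (in - 1 + pad_before + pad_after
                            #          - (kernel - 1) * dilation) // stride + 1
                            # A non-zero remainder means a non-integer output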
2189 remainders = []
2190 outputs = []
2191 for index in range(k_rank):
2192 pad_offset = index * 2
2193 partial = (
2194 ifm_shape[index + 1]
2195 - 1
2196 + p[pad_offset]
2197 + p[pad_offset + 1]
2198 - (k_shape[index] - 1) * d[index]
2199 )
2200 remainders.append(partial % s[index])
2201 outputs.append((partial // s[index]) + 1)
2202
2203 if (
2204 # the parameters must produce integer exact output
2205 error_name != ErrorIf.ConvOutputShapeNonInteger
2206 and max(remainders) == 0
2207 ) or (
2208 error_name == ErrorIf.ConvOutputShapeNonInteger
2209 and max(remainders) > 0
2210 ):
2211 if (
2212 max_dim_size is not None
2213 and max(outputs) >= max_dim_size
2214 ):
2215 # Test will consume too much memory - skip it
2216 continue
2217
2218 # Compliance - number of dot product calculations
2219 if depthwise:
2220 # N*OH*OW*C*M
2221 dots = gtu.product(
2222 (ifm_shape[0], *outputs, *filter_shape[2:])
2223 )
2224 else:
2225 # N*OH*OW*OC or N*OD*OH*OW*OC
2226 dots = gtu.product(
2227 (ifm_shape[0], *outputs, filter_shape[0])
2228 )
2229 args_dict = {
2230 "acc_type": a,
2231 "stride": s,
2232 "pad": p,
2233 "dilation": d,
2234 "kernel": k_shape,
2235 "ks": k_size,
2236 "dot_products": dots,
2237 "shape": ifm_shape,
2238 }
2239
2240                                # Values larger than 9 need a different delimiter in the test name
2241 delim = "" if max(s + p + d) <= 9 else "x"
2242 arg_list.append(
2243 (
2244 "acc{}_st{}_pad{}_dilat{}".format(
2245 testGen.typeStr(a),
2246 delim.join([str(x) for x in s]),
2247 delim.join([str(x) for x in p]),
2248 delim.join([str(x) for x in d]),
2249 ),
2250 args_dict,
2251 )
2252 )
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002253 if (
2254 error_name
2255 and len(arg_list) >= TosaArgGen.MAX_CONV_ERROR_IFS
2256 ):
2257 # Found enough errors
2258 logger.debug(
2259 f"Skipping creating more conv error tests for {error_name}"
2260 )
2261 more_tests = False
Tai Lyf36f2562024-03-14 16:21:29 +00002262 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002263
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002264 arg_list = TosaArgGen._add_data_generators(
2265 testGen,
2266 opName,
evacha019c96eef2024-02-07 11:21:55 +00002267 shapeList,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002268 dtypes[0],
2269 arg_list,
2270 error_name,
2271 )
2272 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002273 return arg_list
2274
2275 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002276 def agFullyConnected(testGen, rng, opName, shapeList, dtypes, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002277
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002278 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002279 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01002280
2281 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002282 accum_dtype = gtu.get_wrong_output_type(opName, rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01002283 elif error_name == ErrorIf.WrongInputType:
2284 # Pick some potentially correct output dtype if input type is incorrect
2285 accum_dtype = DType.INT32
2286 else:
Tai Lyf36f2562024-03-14 16:21:29 +00002287 accum_dtype = dtypes[-1] # use output dtype as accum_dtype
James Ward8b390432022-08-12 20:48:56 +01002288
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002289 # Set up compliance info
2290 args_dict = {
2291 "acc_type": accum_dtype,
2292 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
2293 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
2294 "shape": shapeList[0],
2295 }
2296
2297 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
2298
2299 arg_list = TosaArgGen._add_data_generators(
2300 testGen,
2301 opName,
evacha019c96eef2024-02-07 11:21:55 +00002302 shapeList,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002303 input_dtype,
2304 arg_list,
2305 error_name,
2306 )
2307 # Return list of tuples: (arg_str, args_dict)
2308 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002309
2310 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002311 def agMatMul(testGen, rng, opName, shapeList, dtype, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002312 # Get valid accumulate type(s)
2313 if dtype == DType.INT8:
2314 accum_dtypes = [DType.INT32]
2315 elif dtype == DType.INT16:
2316 accum_dtypes = [DType.INT48]
2317 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002318 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002319 elif dtype == DType.BF16:
2320 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002321 elif dtype == DType.FP32:
2322 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002323 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2324 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002325 elif error_name is None:
2326 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
2327
2328 if error_name == ErrorIf.WrongOutputType:
2329 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002330 accum_dtypes = [gtu.get_wrong_output_type(opName, rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01002331 elif error_name == ErrorIf.WrongInputType:
2332 # Pick some potentially correct output dtype if input type is incorrect
2333 accum_dtypes = [DType.INT32]
2334
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002335 # Set up compliance info
2336 args_dict = {
2337 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
2338 # Set dot_products = N*H*W
2339 "dot_products": gtu.product(
2340 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
2341 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002342 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002343 }
2344
2345 # Create arg tuple of string and dict
2346 arg_list = []
2347 for a in accum_dtypes:
2348 d = args_dict.copy()
2349 d["acc_type"] = a
2350 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01002351
2352 arg_list = TosaArgGen._add_data_generators(
2353 testGen,
2354 opName,
evacha019c96eef2024-02-07 11:21:55 +00002355 shapeList,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002356 dtype,
2357 arg_list,
2358 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002359 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002360 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01002361 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002362
2363 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002364 def agTransposeConv2D(testGen, rng, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002365 arg_list = []
2366
Jeremy Johnson0c716862023-04-13 17:18:19 +01002367 if testGen.args.level8k and error_name is not None:
2368 # Don't produce negative large tests
2369 return arg_list
2370
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002371 ifm_shape = shapeList[0]
2372 filter_shape = shapeList[1]
2373
Tai Lyf36f2562024-03-14 16:21:29 +00002374 accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
2375
2376 if error_name == ErrorIf.WrongAccumulatorType:
2377 accum_dtypes = (
2378 [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
2379 )
James Ward8b390432022-08-12 20:48:56 +01002380
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002381 # Must be rank 4
2382 if error_name != ErrorIf.WrongRank:
2383 assert len(ifm_shape) == 4
2384 assert len(filter_shape) == 4
2385
Jeremy Johnson0c716862023-04-13 17:18:19 +01002386 k_shape = tuple(filter_shape[1:3])
Jeremy Johnson95a67102024-01-10 14:16:39 +00002387 # compliance size - KS
2388 k_size = gtu.product((*k_shape, ifm_shape[3]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002389
Jeremy Johnson0c716862023-04-13 17:18:19 +01002390 if not testGen.args.level8k:
2391 # Generate comprehensive argument lists
2392 # - except for named errors, which use specific invalid value(s)
2393 smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
2394 if error_name == ErrorIf.PadLargerEqualKernel:
2395 max_filter_size = -max(k_shape[0], k_shape[1])
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002396 p_vals = [rng.choice(range(max_filter_size - 10, max_filter_size))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002397 else:
2398 p_vals = [
2399 x
2400 for x in range(
2401 smallest_padding_size, testGen.args.max_conv_padding + 1
2402 )
2403 ]
2404 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2405 if error_name == ErrorIf.StrideSmallerOne:
2406                # Can't use stride=0, as it is used as a divisor when deriving the output shape
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002407 s_vals = [rng.choice(range(-5, 0))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002408 else:
2409 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
2410 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002411
Jeremy Johnson0c716862023-04-13 17:18:19 +01002412 if not error_name and testGen.args.oversize:
2413 # add some oversize argument values
2414 if max(ifm_shape) < 64:
2415 bigPadding = 9
2416 paddings.update(
2417 {
2418 x
2419 for x in itertools.product(
2420 *([[smallest_padding_size, bigPadding]] * 4)
2421 )
2422 }
2423 )
2424 bigStride = 8
2425 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
2426
2427 # There are too many parameter combinations, so generate them sparsely,
2428 # very sparse for negative tests
2429 sparsity_factor = 2 if error_name else 10
2430 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
2431 # If there are only a small number of tests, just select them all
2432 if sparsity < 13:
2433 sparsity = 1
2434 # To get a variety of parameter combinations sparsity should not be a
2435 # multiple of 2, 3 or 5
2436 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
2437 sparsity += 1
2438 else:
2439 # Only test 8k levels boundaries
2440 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2441 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2442 bigPadding = bigKernel
2443
2444 pad_shape = [0] * (len(k_shape) * 2)
2445 stride_shape = [1] * len(k_shape)
2446            # The point at which the input dimension combined with the stride
2447            # will create large output sizes!
2448 LARGE_SIZE = 2
2449 for idx in range(len(k_shape)):
2450 pad_offset = idx * 2
2451 if k_shape[idx] == bigKernel:
2452 # Set large stride
2453 stride_shape[idx] = bigKernel
2454 # Use negative output padding to reduce shape size
2455 pad_shape[pad_offset] = -(bigPadding - 1)
2456 if ifm_shape[idx + 1] > LARGE_SIZE:
2457 pad_shape[pad_offset + 1] = -(bigPadding - 1)
2458 else:
2459 # The other dimension should be the bigKernel
2460 alt_idx = 1 - idx
2461 if (
2462 k_shape[alt_idx] == bigKernel
2463 and ifm_shape[alt_idx + 1] < LARGE_SIZE
2464 ):
2465 # As the input is small, the large stride won't
2466 # affect the output so we can add some padding
2467 pad_shape[pad_offset + 1] = bigPadding
2468
2469 strides = {tuple(stride_shape)}
2470 paddings = {tuple(pad_shape)}
2471
            # Currently allow all combinations that are a reasonable size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002473 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002474
2475 n = 0
Tai Lyf36f2562024-03-14 16:21:29 +00002476 for a in accum_dtypes:
2477 for s in sorted(list(strides)):
2478 for p in sorted(list(paddings)):
2479 if n % sparsity == 0:
2480 # Determine the output shape
2481 oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
2482 ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
2483 os = [ifm_shape[0], oh, ow, filter_shape[0]]
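                        # Worked example of the output height computed above:
                        # with ifm_shape[1]=16, s[0]=2, p[0]=p[1]=1 and
                        # k_shape[0]=3: oh = (16 - 1) * 2 + 1 + 1 + 3 = 35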
Jeremy Johnson0c716862023-04-13 17:18:19 +01002484
Tai Lyf36f2562024-03-14 16:21:29 +00002485 # N*OH*OW*OC
2486 dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
2487 args_dict = {
2488 "acc_type": a,
2489 "stride": s,
2490 "pad": p,
2491 "kernel": k_shape,
2492 "ks": k_size,
2493 "dot_products": dots,
2494 "shape": ifm_shape,
2495 "out_shape": os,
2496 }
Jeremy Johnson95a67102024-01-10 14:16:39 +00002497
                        # Values larger than 9 need a different delimiter
2499 delim = "" if max(s + p) <= 9 else "x"
2500 arg_list.append(
2501 (
2502 "acc{}_st{}_pad{}_os{}".format(
2503 testGen.typeStr(a),
2504 delim.join([str(x) for x in s]),
2505 delim.join([str(x) for x in p]),
2506 "x".join([str(x) for x in os]),
2507 ),
2508 args_dict,
2509 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002510 )
Tai Lyf36f2562024-03-14 16:21:29 +00002511 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002512
Jeremy Johnson95a67102024-01-10 14:16:39 +00002513 arg_list = TosaArgGen._add_data_generators(
2514 testGen,
2515 opName,
evacha019c96eef2024-02-07 11:21:55 +00002516 shapeList,
Jeremy Johnson95a67102024-01-10 14:16:39 +00002517 dtypes[0],
2518 arg_list,
2519 error_name,
2520 )
2521 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002522 return arg_list
2523
2524 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002525 def agPad(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002526 rank = len(shapeList[0])
2527
2528 # Exhaustively test combinations of padding on each side of each dimension
2529 # - the range of padding values is defined by pad_min and pad_max
2530 # - for padding >9, the name format needs to be more distinctive
2531 pad_min, pad_max = 0, 1
2532 pad_values = [x for x in range(pad_min, pad_max + 1)]
2533 if error_name == ErrorIf.PadSmallerZero:
2534 pad_values = [x for x in range(-2, 0)]
2535 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
2536 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
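        # Illustrative example: with pad_min/pad_max of 0/1 and rank 2,
        # axis_pad_values is [(0, 0), (0, 1), (1, 0), (1, 1)] and
        # shape_pad_values yields 4**2 = 16 per-dimension combinations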
2537
2538 if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002539 pad_const_int = rng.randNumberDType(dtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002540 pad_const_fp = 0
Tai Ly60dc48c2024-03-08 22:19:41 +00002541 elif gtu.dtypeIsFloat(dtype):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002542 pad_const_int = 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002543 pad_const_fp = rng.randNumberDType(dtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002544 else:
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002545 assert error_name == ErrorIf.WrongInputType
2546 pad_const_int = 0
2547 pad_const_fp = 0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002548
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002549 list_shape_pad_values = list(shape_pad_values)
2550 # If we are producing tests for rank 6 or greater use sparsity
2551 if len(list_shape_pad_values) > 1024:
2552 sparsity_factor = 2 if error_name else 120
2553 sparsity = TosaArgGen._calculate_sparsity(
2554 len(list_shape_pad_values), sparsity_factor
2555 )
2556 else:
2557 sparsity = 1
2558
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002559 # Build arg list
2560 arg_list = []
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002561 for n, paddings in enumerate(list_shape_pad_values):
James Ward8b390432022-08-12 20:48:56 +01002562 paddings = list(paddings)
2563 args_valid = True
2564
2565 if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while still testing negative padding
2567 for i in range(rank):
2568 dim_after_padding = (
2569 paddings[i][0] + paddings[i][1] + shapeList[0][i]
2570 )
2571 if dim_after_padding < 1:
2572 paddings[i] = (0, 0)
2573 if all([p > -1 for p in paddings[i]]):
2574 args_valid = False
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002575 if args_valid and n % sparsity == 0:
James Ward8b390432022-08-12 20:48:56 +01002576 name = "pad"
2577 for r in range(rank):
2578 before, after = paddings[r]
2579 name = f"{name}{before}{after}"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002580 args_dict = {
2581 "pad": np.array(paddings),
2582 "pad_const_int": pad_const_int,
2583 "pad_const_fp": pad_const_fp,
2584 }
2585 arg_list.append((name, args_dict))
James Ward8b390432022-08-12 20:48:56 +01002586
2587 if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
Jeremy Johnsondd975b82024-02-28 17:29:13 +00002588 logger.debug(
2589 f"agPad: No PadSmallerZero ErrorIf test created for input shape: {shapeList[0]}"
2590 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002591
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002592 arg_list = TosaArgGen._add_data_generators(
2593 testGen,
2594 opName,
evacha019c96eef2024-02-07 11:21:55 +00002595 shapeList,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002596 dtype,
2597 arg_list,
2598 error_name,
2599 )
2600
2601 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002602 return arg_list
2603
2604 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002605 def agPooling(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002606 arg_list = []
2607
2608 shape = shapeList[0]
2609 if error_name != ErrorIf.WrongRank:
2610 assert len(shape) == 4
2611
Jeremy Johnson0c716862023-04-13 17:18:19 +01002612 test_level8k = testGen.args.level8k and error_name is None
2613
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002614 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002615 startKernel = 2
2616 startPad = 0
2617 if not test_level8k:
2618 # Generate comprehensive argument lists
2619 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2620 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2621 # Stride must be greater than 1 to force non-integer error
2622 s_vals = [
2623 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2624 ]
2625 strides = {x for x in itertools.product(*([s_vals] * 2))}
2626 k_vals = [
2627 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2628 ]
2629 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2630 max_dim_size = None
2631 else:
2632 # Only test 8k levels
2633 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2634 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2635 strides = {(1, bigStride), (bigStride, 4)}
2636 kernels = {(1, bigKernel), (bigKernel, 3)}
2637 paddings = set()
2638 for s in sorted(list(strides)):
2639 for k in sorted(list(kernels)):
2640 padding = []
2641 for idx in range(len(k)):
2642 total_padding = s[idx] - shape[idx + 1] + k[idx]
2643 while total_padding < 0:
2644 # Must meet: shape + padding > kernel
2645 total_padding += s[idx]
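                        # Adding the stride keeps (shape + padding - kernel)
                        # divisible by the stride, so this finds the smallest
                        # non-negative padding that still gives an exact
                        # integer pooling output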
2646 if total_padding < k[idx]:
2647 padding.extend([0, total_padding])
2648 else:
2649 # Note this may produce padding >= k[idx] which is not
2650 # allowed - but will be ignored in the creation loop below
2651 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2652 paddings.add(tuple(padding))
2653 # Create a limit for the output dimensions size
2654 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002655
James Ward8b390432022-08-12 20:48:56 +01002656 if opName == "max_pool2d":
2657 accum_dtypes = [None] # max_pool has no accumulate dtype
2658 elif dtype == DType.INT8 or dtype == DType.INT16:
2659 accum_dtypes = [DType.INT32]
2660 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002661 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002662 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002663 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002664 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2665 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002666 elif error_name is None:
2667 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2668 else:
2669 # Set to something for the ErrorIf case which has
2670 # incorrect input data-type
2671 accum_dtypes = [DType.INT32]
2672
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00002673 if error_name == ErrorIf.WrongAccumulatorType:
2674 accum_dtypes = list(gtu.usableDTypes(excludes=accum_dtypes))
2675
Jeremy Johnson0c716862023-04-13 17:18:19 +01002676 if not test_level8k:
2677 if testGen.args.oversize:
2678 # add some oversize argument values
2679 bigStride = 7
2680 bigKernel = 9
2681 strides.update(
2682 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002683 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002684 kernels.update(
2685 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2686 )
2687 if max(shape) < 64:
2688 # padding must be less than the kernel size
2689 bigPadding = bigKernel - 1
2690 paddings.update(
2691 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2692 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002693
            # There are too many parameter combinations, so generate them
            # sparsely - and very sparsely for negative tests
2696 sparsity_factor = 2 if error_name else 500
2697 sparsity = (
2698 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2699 )
2700 else:
2701 # We have already limited test output combinations for 8k tests
2702 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002703
James Ward8b390432022-08-12 20:48:56 +01002704 arg_str = (
2705 "acc{}_st{}_kern{}_pad{}"
2706 if accum_dtypes[0] is not None
2707 else "st{}_kern{}_pad{}"
2708 )
2709
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002710 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01002711 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002712 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01002713
            # Values larger than 9 need a different delimiter
2715 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01002716 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01002717 delim.join([str(x) for x in stride]),
2718 delim.join([str(x) for x in kern]),
2719 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01002720 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002721 args_dict = {
2722 "stride": stride,
2723 "pad": pad,
2724 "kernel": kern,
2725 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002726 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002727 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
2728 }
James Ward8b390432022-08-12 20:48:56 +01002729
2730 if accum is not None:
2731 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002732 args_dict["acc_type"] = accum
2733 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002734
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002735 n = 0
James Ward8b390432022-08-12 20:48:56 +01002736 for a in accum_dtypes:
2737 for s in sorted(list(strides)):
2738 for p in sorted(list(paddings)):
2739 for k in sorted(list(kernels)):
2740 if error_name in [
2741 ErrorIf.StrideSmallerOne,
2742 ErrorIf.KernelSmallerOne,
2743 ErrorIf.PadSmallerZero,
2744 ErrorIf.PadLargerEqualKernel,
2745 ]:
2746 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002747 rng, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002748 )
James Ward8b390432022-08-12 20:48:56 +01002749 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002750 arg_list.append(
                                    get_arg_list_element(a, sNew, pNew, kNew, shape=shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002752 )
James Ward8b390432022-08-12 20:48:56 +01002753 elif (
2754 n % sparsity == 0
2755 # padding must not exceed the kernel size
2756 and p[0] < k[0]
2757 and p[1] < k[0]
2758 and p[2] < k[1]
2759 and p[3] < k[1]
2760 # the padded shape must exceed the kernel size
2761 and (shape[1] + p[0] + p[1]) > k[0]
2762 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002763 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002764 partial_h = shape[1] + p[0] + p[1] - k[0]
2765 partial_w = shape[2] + p[2] + p[3] - k[1]
2766 remainder_h = partial_h % s[0]
2767 remainder_w = partial_w % s[1]
2768 output_h = partial_h // s[0] + 1
2769 output_w = partial_w // s[1] + 1
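                            # e.g. shape[1]=32, p[0]=p[1]=1, k[0]=4, s[0]=2:
                            # partial_h = 32 + 1 + 1 - 4 = 30, remainder_h = 0,
                            # output_h = 30 // 2 + 1 = 16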
Jeremy Johnsonaf090182024-02-13 18:25:39 +00002770 logger.debug(
2771 f"agPooling: {shape} remainder=({remainder_h}, {remainder_w}) output=({output_h}, {output_w})"
2772 )
James Ward8b390432022-08-12 20:48:56 +01002773 if (
2774 # the parameters must produce integer exact output
2775 error_name != ErrorIf.PoolingOutputShapeNonInteger
2776 and remainder_h == 0
2777 and remainder_w == 0
2778 ) or (
2779 error_name == ErrorIf.PoolingOutputShapeNonInteger
2780 and (remainder_h != 0 or remainder_w != 0)
2781 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002782 if (
2783 max_dim_size is not None
2784 and max(output_h, output_w) > max_dim_size
2785 ):
2786 # Test will consume too much memory - skip it
2787 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002788 # Dot products = N*OH*OW*C
2789 dp = gtu.product(
2790 (shape[0], output_h, output_w, shape[3])
2791 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002792 arg_list.append(
2793 get_arg_list_element(a, s, p, k, dp, shape)
2794 )
James Ward8b390432022-08-12 20:48:56 +01002795 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002796
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002797 # Now add data generator types
2798 arg_list = TosaArgGen._add_data_generators(
2799 testGen,
2800 opName,
evacha019c96eef2024-02-07 11:21:55 +00002801 shapeList,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002802 dtype,
2803 arg_list,
2804 error_name,
2805 )
2806
2807 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002808 return arg_list
2809
2810 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002811 def agCast(testGen, rng, opName, shapeList, inDtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002812 arg_list = []
2813
2814 # Enumerate the output types here
2815 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002816 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(inDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002817 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00002818 dtypeList = [
2819 DType.BOOL,
2820 DType.INT16,
2821 DType.INT32,
2822 DType.FP16,
2823 DType.BF16,
2824 DType.FP32,
2825 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002826 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00002827 dtypeList = [
2828 DType.BOOL,
2829 DType.INT8,
2830 DType.INT32,
2831 DType.FP16,
2832 DType.BF16,
2833 DType.FP32,
2834 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002835 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00002836 dtypeList = [
2837 DType.BOOL,
2838 DType.INT8,
2839 DType.INT16,
2840 DType.FP16,
2841 DType.BF16,
2842 DType.FP32,
2843 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002844 elif inDtype == DType.BOOL:
2845 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01002846 elif inDtype == DType.FP16:
Won Jeon2c34b462024-02-06 18:37:00 +00002847 dtypeList = [
2848 DType.INT8,
2849 DType.INT16,
2850 DType.INT32,
2851 DType.FP32,
2852 DType.FP8E4M3,
2853 DType.FP8E5M2,
2854 ]
James Ward24dbc422022-10-19 12:20:31 +01002855 elif inDtype == DType.BF16:
Won Jeon2c34b462024-02-06 18:37:00 +00002856 dtypeList = [
2857 DType.INT8,
2858 DType.INT16,
2859 DType.INT32,
2860 DType.FP32,
2861 DType.FP8E4M3,
2862 DType.FP8E5M2,
2863 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002864 elif inDtype == DType.FP32:
Won Jeon2c34b462024-02-06 18:37:00 +00002865 dtypeList = [
2866 DType.INT8,
2867 DType.INT16,
2868 DType.INT32,
2869 DType.FP16,
2870 DType.BF16,
2871 DType.FP8E4M3,
2872 DType.FP8E5M2,
2873 ]
2874 elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]:
2875 dtypeList = [DType.FP16, DType.BF16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002876 elif error_name == ErrorIf.WrongInputType:
2877 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002878 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002879 else:
2880 raise Exception("Unexpected input dtype: {}".format(inDtype))
2881
2882 for dtype in dtypeList:
Jeremy Johnson708da822023-11-15 16:25:45 +00002883 arg_list.append(
2884 ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
2885 )
2886
2887 # Now add data generator types
2888 arg_list = TosaArgGen._add_data_generators(
2889 testGen,
2890 opName,
evacha019c96eef2024-02-07 11:21:55 +00002891 shapeList,
Jeremy Johnson708da822023-11-15 16:25:45 +00002892 dtype,
2893 arg_list,
2894 error_name,
2895 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002896
2897 return arg_list
2898
2899 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002900 def agRescale(testGen, rng, opName, shapeList, inDtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002901 arg_list = []
2902
2903 # Enumerate the output types here
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002904 for outDtype in [
2905 DType.UINT8,
2906 DType.INT8,
2907 DType.INT16,
2908 DType.INT32,
2909 DType.UINT16,
2910 ]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002911 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002912 outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002913 and error_name == ErrorIf.OutputZeroPointNotZero
2914 ):
2915 continue
2916 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002917 outDtype != DType.UINT16
2918 and error_name == ErrorIf.U16OutputZeroPointNotValid
2919 ) or (
2920 inDtype != DType.UINT16
2921 and error_name == ErrorIf.U16InputZeroPointNotValid
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002922 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002923 # ErrorIfs only valid with UINT16
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002924 continue
2925 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002926 inDtype == DType.UINT8
2927 and outDtype not in [DType.INT8, DType.INT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002928 and error_name != ErrorIf.WrongOutputType
2929 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002930 # The only output dtypes for UINT8 are INT8/INT16, skip all others
2931 continue
2932 if (
2933 inDtype not in [DType.INT8, DType.INT16]
2934 and outDtype == DType.UINT8
2935 and error_name != ErrorIf.WrongOutputType
2936 ):
2937 # The only input dtypes for UINT8 are INT8/INT16, skip all others
2938 continue
2939 if (
2940 inDtype == DType.UINT16
2941 and outDtype != DType.INT16
2942 and error_name != ErrorIf.WrongOutputType
2943 ):
2944 # The only output dtype for UINT16 is INT16, skip all others
2945 continue
2946 if (
2947 inDtype != DType.INT16
2948 and outDtype == DType.UINT16
2949 and error_name != ErrorIf.WrongOutputType
2950 ):
2951 # The only input dtype for UINT16 is INT16, skip all others
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002952 continue
2953 if (
2954 error_name == ErrorIf.WrongOutputType
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002955 and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002956 ):
2957 continue
2958
2959 for scale32 in [False, True]:
2960 if error_name == ErrorIf.ScaleTrue and not scale32:
2961 continue
2962 elif error_name == ErrorIf.ScaleNotTrue and scale32:
2963 continue
2964 for double_round in [False, True]:
2965 if error_name == ErrorIf.ScaleNotTrue and not double_round:
2966 continue
2967 for per_channel in [False, True]:
2968
2969 if (
2970 inDtype == DType.INT48
2971 and scale32
2972 and error_name != ErrorIf.ScaleTrue
2973 ):
2974 # Illegal condition. Must be scale32=False
2975 continue
2976 if (
2977 double_round
2978 and not scale32
2979 and error_name != ErrorIf.ScaleNotTrue
2980 ):
2981 # Illegal condition. ERROR_IF(!scale32 && double_round)
2982 continue
2983
Tai Ly6e1e2bc2024-03-01 20:59:32 +00002984 if per_channel:
2985 nc = shapeList[0][-1]
2986 else:
2987 nc = 1
2988
2989 in_type_width = gtu.dtypeWidth(inDtype)
2990 out_type_width = gtu.dtypeWidth(outDtype)
2991
2992 # Calculate scale based on:
                        # scale = a * (2^output_width) / (2^input_width)
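                        # e.g. INT8 input to INT16 output: widths 8 and 16,
                        # so scale_arr = a * 256 with a in [0, 1), giving a
                        # random scale in roughly [0, 256)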
2994
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002995 a = np.float32(rng.random(size=[nc]))
Tai Ly6e1e2bc2024-03-01 20:59:32 +00002996 scale_arr = a * np.float32(
2997 (1 << out_type_width) / (1 << in_type_width)
2998 )
2999
3000 if scale32:
3001 # Cap the scaling at 2^31 - 1 for scale32
3002 scale_arr = np.clip(
3003 scale_arr, 1.0 / (1 << 31), (1 << 31) - 1
3004 )
3005 else:
3006 # Cap the scaling at 2^15 - 1 for scale16
3007 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
3008
Jeremy Johnsonaf090182024-02-13 18:25:39 +00003009 logger.debug(
3010 f"agRescale: {out_type_width} {in_type_width} -> {scale_arr}"
3011 )
Tai Ly6e1e2bc2024-03-01 20:59:32 +00003012
3013 multiplier_arr = np.int32(np.zeros(shape=[nc]))
3014 shift_arr = np.int32(np.zeros(shape=[nc]))
3015 for i in range(nc):
3016 (
3017 multiplier_arr[i],
3018 shift_arr[i],
3019 ) = TosaQuantGen.computeMultiplierAndShift(
3020 scale_arr[i], scale32
3021 )
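                        # The (multiplier, shift) pair is a fixed-point
                        # encoding of the scale - approximately
                        # multiplier * 2^-shift as applied by RESCALE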
3022
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003023 arg_list.append(
3024 (
3025 "out{}_sc{}_dr{}_pc{}".format(
Jeremy Johnson3b0544c2022-10-18 16:32:19 +01003026 testGen.typeStr(outDtype),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003027 int(scale32),
3028 int(double_round),
3029 int(per_channel),
3030 ),
Jeremy Johnson587cc842024-02-08 11:45:44 +00003031 {
3032 "output_dtype": outDtype,
3033 "scale": scale32,
3034 "double_round": double_round,
3035 "per_channel": per_channel,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00003036 "multiplier": multiplier_arr,
3037 "shift": shift_arr,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003038 },
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003039 )
3040 )
3041
Jeremy Johnson587cc842024-02-08 11:45:44 +00003042 arg_list = TosaArgGen._add_data_generators(
3043 testGen,
3044 opName,
evacha019c96eef2024-02-07 11:21:55 +00003045 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003046 inDtype,
3047 arg_list,
3048 error_name,
3049 )
3050 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003051 return arg_list
3052
3053 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003054 def agMul(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003055 arg_list = []
3056
3057 if dtype is DType.INT32:
3058 for p in range(testGen.args.num_rand_permutations):
3059
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003060 shift = rng.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003061 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003062 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003063 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003064
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003065 arg_list = TosaArgGen._add_data_generators(
3066 testGen,
3067 opName,
evacha019c96eef2024-02-07 11:21:55 +00003068 shapeList,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003069 dtype,
3070 arg_list,
3071 error_name,
3072 )
3073 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003074 return arg_list
3075
3076 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003077 def agArithmeticRightShift(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003078 arg_list = []
3079
Jeremy Johnson587cc842024-02-08 11:45:44 +00003080 for round in (True, False):
3081 args_dict = {
3082 "round": round,
3083 }
3084 arg_list.append((f"round{round}", args_dict))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003085
Jeremy Johnson587cc842024-02-08 11:45:44 +00003086 arg_list = TosaArgGen._add_data_generators(
3087 testGen,
3088 opName,
evacha019c96eef2024-02-07 11:21:55 +00003089 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003090 dtype,
3091 arg_list,
3092 error_name,
3093 )
3094 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003095 return arg_list
3096
Luke Hutton57287132023-02-06 14:54:18 +00003097 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003098 def agFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Luke Hutton57287132023-02-06 14:54:18 +00003099 arg_list = []
3100
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003101 shape = shapeList[0]
3102 dot_products = gtu.product(shape)
3103 ks = 2 * shape[1] * shape[2] # 2*H*W
3104 for inverse in (True, False):
3105 args_dict = {
3106 "dot_products": dot_products,
3107 "shape": shape,
3108 "ks": ks,
3109 "acc_type": dtype,
3110 "inverse": inverse,
3111 }
3112 arg_list.append((f"inverse{inverse}", args_dict))
Luke Hutton57287132023-02-06 14:54:18 +00003113
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003114 arg_list = TosaArgGen._add_data_generators(
3115 testGen,
3116 opName,
evacha019c96eef2024-02-07 11:21:55 +00003117 shapeList,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003118 dtype,
3119 arg_list,
3120 error_name,
3121 )
3122 # Return list of tuples: (arg_str, args_dict)
Luke Hutton57287132023-02-06 14:54:18 +00003123 return arg_list
3124
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003125 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003126 def agRFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003127 arg_list = []
3128
3129 shape = shapeList[0]
3130 dot_products = gtu.product(shape)
3131 ks = shape[1] * shape[2] # H*W
3132 args_dict = {
3133 "dot_products": dot_products,
3134 "shape": shape,
3135 "ks": ks,
3136 "acc_type": dtype,
3137 }
3138 arg_list.append(("", args_dict))
3139
3140 arg_list = TosaArgGen._add_data_generators(
3141 testGen,
3142 opName,
evacha019c96eef2024-02-07 11:21:55 +00003143 shapeList,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003144 dtype,
3145 arg_list,
3146 error_name,
3147 )
3148 # Return list of tuples: (arg_str, args_dict)
3149 return arg_list
3150
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003151 # Helper function for reshape. Gets some factors of a larger number.
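    # e.g. getFactors(24) returns [1, 2, 3, 4] - the divisors of 24 that
    # are no greater than sqrt(24)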
3152 @staticmethod
3153 def getFactors(val, start=1):
3154 factors = []
3155
3156 for i in range(start, int(np.sqrt(val)) + 1):
3157 if (val % i) == 0:
3158 factors.append(i)
3159
3160 return factors
3161
3162 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003163 def agReshape(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003164 arg_list = []
3165
3166 origShape = shapeList[0]
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003167 totalElements = gtu.product(origShape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003168 factors = TosaArgGen.getFactors(totalElements)
3169
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003170 # Find new shapes up to the number of permutations asked for
3171 # This code is NOT fast. Fortunately, the numbers are fairly small.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003172 for p in range(testGen.args.num_rand_permutations):
Jeremy Johnsondd975b82024-02-28 17:29:13 +00003173 # Rank from 1 to MAX_TENSOR_RANK
3174 newRank = rng.randInt(1, (gtu.MAX_TENSOR_RANK + 1))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003175 if len(factors) < newRank:
3176 continue
3177
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003178 # escape_counter limits the generation of new shapes to a reasonable time
3179 for escape_counter in range(100):
3180
3181 # Generate the new shape of the chosen new rank
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003182 newShape = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003183 remainingElements = totalElements
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003184 shuffledFactors = rng.permutation(factors)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003185 for i in range(1, newRank):
3186 # pick rank-1 factors
3187 newShape.append(shuffledFactors[0])
3188 remainingElements = remainingElements // shuffledFactors[0]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003189 shuffledFactors = rng.permutation(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003190 TosaArgGen.getFactors(remainingElements)
3191 )
3192 newShape.append(remainingElements)
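                # Illustrative example: with totalElements=24 and newRank=3,
                # picking factors 2 then 3 leaves remainingElements=4,
                # giving newShape=[2, 3, 4] (product still 24)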
3193
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003194 # Check for duplicates
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003195 duplicate = False
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003196 for name, args_dict in arg_list:
3197 if args_dict["new_shape"] == newShape:
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003198 duplicate = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003199 break
3200
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003201 if not duplicate:
3202 outShape = "x".join([str(x) for x in newShape])
3203 arg_list.append(
3204 (
3205 "perm{}_rank{}_out{}".format(p, newRank, outShape),
3206 {"new_shape": newShape},
3207 )
3208 )
3209 # Found an output shape for this permutation
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003210 break
3211
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003212 # Now add data generator types
3213 arg_list = TosaArgGen._add_data_generators(
3214 testGen,
3215 opName,
evacha019c96eef2024-02-07 11:21:55 +00003216 shapeList,
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003217 dtype,
3218 arg_list,
3219 error_name,
3220 )
3221
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003222 return arg_list
3223
3224 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003225 def agTranspose(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003226 arg_list = []
3227
3228 ifm_shape = shapeList[0]
3229
3230 if error_name == ErrorIf.IndexOutsideBounds:
3231 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
3232 incorrect_small_index = range(-len(ifm_shape), 0)
3233 permutations = [p for p in itertools.permutations(incorrect_large_index)]
3234 permutations.extend(
3235 [p for p in itertools.permutations(incorrect_small_index)]
3236 )
3237 elif error_name == ErrorIf.IndexUsedTwice:
3238 # Create list with a duplicated index
3239 perm_range = list(range(len(ifm_shape)))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003240 index_choice = rng.choice(range(len(perm_range)))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003241 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
3242 permutations = [p for p in itertools.permutations(perm_range)]
3243
3244 else:
3245 # Get all permutations
3246 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
3247
3248 # Limit to possible permutations from shape dimension or argument setting
3249 limit = min(len(permutations), testGen.args.num_rand_permutations)
3250
        # Randomly shuffle the full set of permutations
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003252 random_permutations = rng.permutation(permutations)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003253
3254 # Create list of required amount of permutations
3255 arg_list = [
evacha0198477222024-01-26 12:25:32 +00003256 ("perm{}".format(p), {"perms": random_permutations[p].tolist()})
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003257 for p in range(limit)
3258 ]
evacha0198477222024-01-26 12:25:32 +00003259 # Now add data generator types
3260 arg_list = TosaArgGen._add_data_generators(
3261 testGen,
3262 opName,
evacha019c96eef2024-02-07 11:21:55 +00003263 shapeList,
evacha0198477222024-01-26 12:25:32 +00003264 dtype,
3265 arg_list,
3266 error_name,
3267 )
3268 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003269 return arg_list
3270
3271 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003272 def agSlice(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003273 arg_list = []
3274
3275 ifm_shape = shapeList[0]
3276 rank = len(ifm_shape)
3277
3278 for p in range(testGen.args.num_rand_permutations):
3279 start = []
3280 size = []
3281
3282 valid = True
3283
3284 for i in range(rank):
3285 if ifm_shape[i] > 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003286 start.append(rng.randInt(0, ifm_shape[i]))
3287 size.append(rng.randInt(0, ifm_shape[i] - start[i]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003288
3289 # Invalid slice size?
3290 if size[i] == 0:
3291 valid = False
3292 else:
3293 start.append(0)
3294 size.append(1)
3295
3296 if valid:
                # If an ERROR_IF test is required, incorrect start/size values are returned
3298 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003299 rng, error_name, ifm_shape, start, size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003300 )
evacha017f7d4252024-01-24 12:08:09 +00003301 arg_list.append(("perm{}".format(p), {"start": start, "size": size}))
3302 # Now add data generator types
3303 arg_list = TosaArgGen._add_data_generators(
3304 testGen,
3305 opName,
evacha019c96eef2024-02-07 11:21:55 +00003306 shapeList,
evacha017f7d4252024-01-24 12:08:09 +00003307 dtype,
3308 arg_list,
3309 error_name,
3310 )
3311 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003312 return arg_list
3313
3314 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003315 def agTile(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003316 arg_list = []
3317
3318 ifm_shape = shapeList[0]
3319 rank = len(ifm_shape)
3320
3321 for p in range(testGen.args.num_rand_permutations):
3322
3323 # Pick a few random, but small multiple values
3324 # because otherwise this has a tendency to generate
3325 # enormous tensors
3326 multiples = []
3327 for i in range(rank):
3328 if ifm_shape[i] > 1000:
3329 # Multiple of 1 if ifm_shape dimension is large to reduce
3330 # tensor size
3331 multiples.append(1)
3332 elif max(ifm_shape) > 1000:
3333 multiples.append(2)
3334 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003335 multiples.append(rng.randInt(1, 4))
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003336 arg_list.append(("perm{}".format(p), {"multiples": multiples}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003337
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003338 # Now add data generator types
3339 arg_list = TosaArgGen._add_data_generators(
3340 testGen,
3341 opName,
evacha019c96eef2024-02-07 11:21:55 +00003342 shapeList,
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003343 dtype,
3344 arg_list,
3345 error_name,
3346 )
3347 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003348 return arg_list
3349
3350 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003351 def agResize(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003352 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003353 ifm_shape = shapeList[0]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003354
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003355 def get_aspect_ratio_resize_params():
3356 common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003357 aspect_ratio = rng.choice(common_aspect_ratios)
3358 invert = rng.choice((False, True))
3359 letterbox = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003360
3361 scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
3362 scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
3363 scale_y_d = scale_x_d = 1
3364 offset_x = offset_y = 0
3365
3366 if letterbox:
3367 max_border = scale_y_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003368 border_y = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003369 border_x = 0
3370 else:
3371 # Pillarboxing
3372 border_y = 0
3373 max_border = scale_x_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003374 border_x = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003375
3376 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3377 offset = (offset_y, offset_x)
3378 border = (border_y, border_x)
3379
3380 return scale, offset, border
3381
3382 def get_upscale_downscale_params():
3383 valid_params = False
3384 while not valid_params:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003385 upscale = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003386
3387 # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003388 origin_sampling = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003389
3390 if upscale:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003391 shift = rng.randInt(low=1, high=4)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003392 scale_x_d = scale_y_d = 1
3393 scale_x_n = scale_y_n = (
3394 1 << shift if origin_sampling else 2 << shift
3395 )
3396 border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
3397 offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
3398 else:
3399 scale_x_n = 1
3400 scale_y_n = 1
3401
3402 # Return list of valid scale_*_d values (max value 4) given input dim shape
3403 def get_valid_denom(ifm_dim):
3404 return [x for x in range(1, 5) if ifm_dim % x == 1]
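                    # ifm_dim % x == 1 means (ifm_dim - 1) is exactly
                    # divisible by x, e.g. ifm_dim=9 gives [2, 4]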
3405
3406 # Generate list of valid downscale values and choose one randomly
3407 valid_scale_y_ds = get_valid_denom(ifm_shape[1])
3408 valid_scale_x_ds = get_valid_denom(ifm_shape[2])
3409
3410 if not valid_scale_y_ds and not valid_scale_x_ds:
3411 # Bad parameters, skip
3412 continue
3413
3414 if not valid_scale_y_ds:
3415 scale_y_d = 1
3416 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003417 scale_y_d = rng.choice(valid_scale_y_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003418
3419 if not valid_scale_x_ds:
3420 scale_x_d = 1
3421 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003422 scale_x_d = rng.choice(valid_scale_x_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003423
3424 border_x = border_y = 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003425 offset_y = rng.randInt(0, 16 * scale_y_n)
3426 offset_x = rng.randInt(0, 16 * scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003427 valid_params = True
3428
3429 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3430 offset = (offset_y, offset_x)
3431 border = (border_y, border_x)
3432 return scale, offset, border
3433
3434 def get_rand_params():
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003435 def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
3436 scale = scale_n / scale_d
3437 if scale > max_scale:
3438 factor = scale / max_scale
3439 new_scale_d = math.ceil(scale_d * factor)
3440 assert scale_n / new_scale_d <= max_scale
3441 scale_d = new_scale_d
3442 return scale_d
3443
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003444 # Scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003445 scale_y_n = rng.randInt(low=1, high=(1 << 11))
3446 scale_x_n = rng.randInt(low=1, high=(1 << 11))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003447
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003448 scale_y_d = rng.randInt(low=1, high=(16 * scale_y_n))
3449 scale_x_d = rng.randInt(low=1, high=(16 * scale_x_n))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003450
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003451 scale_y_d = fix_scale_to_max_scale(
3452 scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3453 )
3454 scale_x_d = fix_scale_to_max_scale(
3455 scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3456 )
3457
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003458 # Offsets and border within the scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003459 offset_y = rng.randInt(low=-scale_y_n, high=(16 * scale_y_n))
3460 offset_x = rng.randInt(low=-scale_x_n, high=(16 * scale_x_n))
3461 border_y = rng.randInt(low=(-16 * scale_y_n), high=scale_y_n)
3462 border_x = rng.randInt(low=(-16 * scale_x_n), high=scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003463
3464 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3465 offset = (offset_y, offset_x)
3466 border = (border_y, border_x)
3467 return scale, offset, border
3468
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003469 def get_level_8k_params():
3470 # Create 64x scale - 64/1 to 2048/32
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003471 scale_d = rng.randInt(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003472 low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
3473 )
3474 scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
3475 # Create half to fifth scaling
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003476 scale_d_alt = rng.randInt(low=2, high=6)
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003477 scale_n_alt = 1
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003478 switch = rng.choice((False, True))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003479 if switch:
3480 scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
3481 else:
3482 scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)
3483
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003484 offset_y = rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
3485 offset_x = rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003486 offset = (offset_y, offset_x)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003487 border_y = rng.choice((-16 * scale[0], 0, scale[0] - 1))
3488 border_x = rng.choice((-16 * scale[2], 0, scale[2] - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003489 border = (border_y, border_x)
3490 return scale, offset, border
3491
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003492 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003493 # Exclude illegal {mode, type} configurations. Pick legal output types
3494 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
3495 outputDTypeList = [DType.INT8]
3496 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
3497 outputDTypeList = [DType.INT16]
3498 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
3499 outputDTypeList = [DType.INT32]
3500 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
3501 outputDTypeList = [DType.INT48]
James Ward8b390432022-08-12 20:48:56 +01003502 elif dtype == DType.FP16:
3503 outputDTypeList = [DType.FP16]
James Ward24dbc422022-10-19 12:20:31 +01003504 elif dtype == DType.BF16:
3505 outputDTypeList = [DType.BF16]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003506 elif dtype == DType.FP32:
3507 outputDTypeList = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00003508 elif dtype == DType.FP8E4M3:
3509 outputDTypeList = [DType.FP8E4M3]
3510 elif dtype == DType.FP8E5M2:
3511 outputDTypeList = [DType.FP8E5M2]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003512 elif error_name == ErrorIf.WrongInputType:
3513 # If an incorrect input type is used then we set a 'correct'
3514 # output type to avoid other errors
3515 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
3516 else:
3517 continue
3518
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003519 arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"
3520
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003521 for outputDType in outputDTypeList:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003522 perm = 0
3523 while perm < testGen.args.num_rand_permutations:
3524 # Random choice of type of params we are testing
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003525 if not testGen.args.level8k:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003526 _rnd_param_fn = rng.choice(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003527 (
3528 get_rand_params,
3529 get_upscale_downscale_params,
3530 get_aspect_ratio_resize_params,
3531 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003532 )
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003533 scale, offset, border = _rnd_param_fn()
3534 else:
3535 scale, offset, border = get_level_8k_params()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003536
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003537 # Expand params for bounds-checking
3538 (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
3539 (offset_y, offset_x) = offset
3540 (border_y, border_x) = border
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003541
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003542 # Make sure output dimensions OH and OW are integers
3543 partial_output_y = (
3544 (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
3545 )
3546 partial_output_x = (
3547 (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
3548 )
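                    # e.g. ifm_shape[1]=16, scale_y_n=2, offset_y=0,
                    # border_y=0: partial_output_y = (16 - 1) * 2 = 30 and,
                    # below, output_y = 30 // scale_y_d + 1 = 31 when
                    # scale_y_d == 1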
3549 if error_name == ErrorIf.ResizeOutputShapeNonInteger:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003550 # Look for non-integer test
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003551 if (
3552 partial_output_y % scale_y_d == 0
3553 and partial_output_x % scale_x_d == 0
3554 ):
3555 # Skip this test as it doesn't produce NonInteger output
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003556 if perm > 0:
3557 perm += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003558 continue
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003559 else:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003560 # Alter the scaling factors to make the output integer
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003561 while partial_output_y % scale_y_d != 0:
3562 scale_y_d -= 1
3563 while partial_output_x % scale_x_d != 0:
3564 scale_x_d -= 1
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003565 # Make sure we are still within max scaling
3566 if (
3567 scale_y_n / scale_y_d
3568 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
3569 scale_x_n / scale_x_d
3570 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
3571 # Skip the test as it is using too large a scaling factor
3572 if perm > 0:
3573 perm += 1
3574 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003575
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003576 output_y = partial_output_y // scale_y_d + 1
3577 output_x = partial_output_x // scale_x_d + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003578
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003579 if (
3580 output_y >= testGen.args.max_resize_output_dim
3581 or output_x >= testGen.args.max_resize_output_dim
3582 ) and error_name is None:
3583 # Skip positive test if output dim will be too high
3584 # Avoid high test latency and OOM issues
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003585 if not testGen.args.level8k or perm > 0:
3586 perm += 1
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003587 continue
3588
3589 if (
3590 output_y <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003591 or output_y >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003592 or output_x <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003593 or output_x >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003594 ):
3595 # Output dimensions out of scope
3596 if error_name is not None and perm > 0:
3597 # As long as we have one ERROR_IF test, don't worry
3598 # about creating all the other permutations
3599 perm += 1
3600 continue
3601
3602 if error_name == ErrorIf.ResizeOutputShapeMismatch and (
3603 (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003604 output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003605 and output_y - scale_y_d < 1
3606 )
3607 or (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003608 output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003609 and output_x - scale_x_d < 1
3610 )
3611 ):
3612 # Can't create a negative test with these params as it
3613 # will create invalid output size
3614 if perm > 0:
3615 perm += 1
3616 continue
3617
3618 scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
3619 offset = [offset_y, offset_x]
3620 border = [border_y, border_x]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003621
3622 # Common for all data types
3623 if error_name is not None:
3624 (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003625 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003626 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003627 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003628 outputDTypeNew,
3629 ) = TosaErrorIfArgGen.eiResizeErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003630 rng,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003631 error_name,
3632 mode,
3633 dtype,
3634 shapeList,
3635 outputDType,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003636 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003637 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003638 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003639 )
3640 else:
3641 outputDTypeNew = outputDType
3642
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003643 arg_to_append = (
3644 arg_str.format(
3645 "N" if mode == ResizeMode.NEAREST else "B",
3646 testGen.typeStr(outputDTypeNew),
3647 scale[0],
3648 scale[1],
3649 scale[2],
3650 scale[3],
3651 offset[0],
3652 offset[1],
3653 border[0],
3654 border[1],
3655 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003656 {
3657 "mode": mode,
3658 "scale": scale,
3659 "offset": offset,
3660 "border": border,
3661 "output_dtype": outputDTypeNew,
3662 },
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003663 )
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003664 if arg_to_append in arg_list:
3665 # Skip already generated test params
3666 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003667
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003668 # Valid permutation
3669 perm += 1
3670 arg_list.append(arg_to_append)
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003671
3672 # Now add data generator types
3673 arg_list = TosaArgGen._add_data_generators(
3674 testGen,
3675 opName,
evacha019c96eef2024-02-07 11:21:55 +00003676 shapeList,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003677 dtype,
3678 arg_list,
3679 error_name,
3680 )
3681 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003682 return arg_list
3683
3684 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003685 def agTable(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003686 arg_list = []
3687
3688 if dtype == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003689 table = np.int32(rng.integers(low=-128, high=128, size=[256])).tolist()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003690 else: # INT16
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003691 table = np.int32(rng.integers(low=-32768, high=32768, size=[513])).tolist()
        # Make sure all slopes are within the REQUIRE min/max 16-bit int range
3693 for idx in range(len(table) - 1):
3694 slope = table[idx + 1] - table[idx]
3695 # Alter the next table entry to force the slope to be ok
3696 if slope > 32767:
3697 table[idx + 1] -= slope - 32767
3698 if slope < -32768:
3699 table[idx + 1] -= slope + 32768
3700 slope = table[idx + 1] - table[idx]
3701 assert slope <= 32767 and slope >= -32768
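            # e.g. table[idx]=-20000, table[idx+1]=20000 gives slope=40000,
            # so table[idx+1] is reduced by 40000 - 32767 = 7233 to 12767,
            # leaving a slope of exactly 32767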
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003702 arg_list.append(
3703 (
3704 "",
Jeremy Johnson587cc842024-02-08 11:45:44 +00003705 {"table": table},
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003706 )
3707 )
Jeremy Johnson587cc842024-02-08 11:45:44 +00003708 # Now add data generator types
3709 arg_list = TosaArgGen._add_data_generators(
3710 testGen,
3711 opName,
evacha019c96eef2024-02-07 11:21:55 +00003712 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003713 dtype,
3714 arg_list,
3715 error_name,
3716 )
3717 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003718 return arg_list
3719
    @staticmethod
    def agCondIf(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003721 # CondIf generates the condition values here.
3722 # Convert to tensors in the build function, along with the
3723 # then and else blocks
3724 arg_list = []
3725
3726 for c in [False, True]:
Jeremy Johnson587cc842024-02-08 11:45:44 +00003727 arg_list.append(("cond{}".format(int(c)), {"condition": c}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003728
Jeremy Johnson587cc842024-02-08 11:45:44 +00003729 # Now add data generator types
3730 arg_list = TosaArgGen._add_data_generators(
3731 testGen,
3732 opName,
evacha019c96eef2024-02-07 11:21:55 +00003733 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003734 dtype,
3735 arg_list,
3736 error_name,
3737 )
3738 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003739 return arg_list
3740
    @staticmethod
    def agWhileLoop(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003742 # While loop: 0 iterations, 1, more than 1
3743 arg_list = []
3744
Jeremy Johnson587cc842024-02-08 11:45:44 +00003745 for iterations in [0, 1, 4]:
3746 arg_list.append(("iter{}".format(iterations), {"iterations": iterations}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003747
Jeremy Johnson587cc842024-02-08 11:45:44 +00003748 # Now add data generator types
3749 arg_list = TosaArgGen._add_data_generators(
3750 testGen,
3751 opName,
evacha019c96eef2024-02-07 11:21:55 +00003752 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003753 dtype,
3754 arg_list,
3755 error_name,
3756 )
3757 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003758 return arg_list