# Copyright (c) 2021-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
import logging
import math

import generator.tosa_utils as gtu
import numpy as np
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode

# DTypeNames, DType, Op and ResizeMode are convenience aliases for the
# flatc-generated types that should be enums, but aren't

logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")


class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """
28
29 def __init__(self):
30 pass
31
32 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010033 def getZeroPoint(rng, zeropoint, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010034
35 if dtype == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010036 if zeropoint is not None:
37 return min(127, max(-128, zeropoint))
38 return rng.randInt(-128, 128)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010039 elif dtype == DType.UINT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010040 if zeropoint is not None:
41 return min(255, max(0, zeropoint))
42 return rng.randInt(0, 256)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010043 elif error_name in [
44 ErrorIf.InputZeroPointNotZero,
45 ErrorIf.WeightZeroPointNotZero,
46 ErrorIf.OutputZeroPointNotZero,
47 ]:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010048 zero_point = rng.randInt(-128, 128)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010049 if zero_point == 0:
50 zero_point = 1
51 return zero_point
52 return 0
53
54 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010055 def qgUnary(rng, zeropoint, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010056 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000057 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010058 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
59 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000060 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010061 elif error_name == ErrorIf.OutputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000062 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010063 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
64 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000065 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010066 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000067 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010068 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
69 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000070 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010071 return qinfo
72
73 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010074 def qgConv(rng, zeropoint, op, dtype_or_dtypeList, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010075 if isinstance(dtype_or_dtypeList, list):
76 # a list of [input, weights, accumulator] dtypes
77 dtypeList = dtype_or_dtypeList
78 else:
79 # an int, [input, weights, accumulator] dtypes are the same
80 dtypeList = [dtype_or_dtypeList] * 3
81
82 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000083 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010084 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0], error_name),
85 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1]),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000086 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010087 elif error_name == ErrorIf.WeightZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000088 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010089 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0]),
90 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1], error_name),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000091 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010092 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000093 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010094 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0]),
95 TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1]),
Eric Kunzeb5fabec2022-06-07 05:20:44 +000096 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010097 return qinfo
98
99 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100100 def qgMatmul(rng, zeropoint, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100101 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000102 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100103 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
104 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000105 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100106 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000107 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100108 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
109 TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000110 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100111 return qinfo
112
113 @staticmethod
114 def computeMultiplierAndShift(scaleFp, scale32):
115 # Derived from computeMultiplierAndShiftTosaScale32
116 # Provide a floating-point scaling factor and the scale32 parameter
117 # to compute the multiplier and shift
118
119 if scale32:
120 scaleBits = 31
121 else:
122 scaleBits = 15
123
124 m, shift = math.frexp(scaleFp)
125
126 if scaleFp < 0.0:
127 m = -m
128
129 multiplier = round(m * (1 << scaleBits))
130 assert multiplier <= (1 << scaleBits)
131
132 if multiplier == (1 << scaleBits):
133 multiplier = multiplier // 2
134 shift = shift + 1
135
136 shift = (-shift) + scaleBits
Jeremy Johnsonaf090182024-02-13 18:25:39 +0000137 logger.debug(
138 f"computeMultiplierAndShift: scalefp={scaleFp} scaleBits={scaleBits} m={m} mult={multiplier} shift={shift}"
139 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100140
141 # Adjust multiplier such that shift is in allowed value range.
142 if shift == 0:
143 multiplier = multiplier // 4
144 shift = shift + 2
145 elif shift == 1:
146 multiplier = multiplier // 2
147 shift = shift + 1
148 elif shift == 63:
149 multiplier = multiplier * 2
150 shift = shift - 1
151
152 assert multiplier <= (1 << scaleBits)
153 assert shift >= 2 and shift <= 62
154
155 return multiplier, shift
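    # Worked example for computeMultiplierAndShift (a sketch, not used by the
    # tests themselves): a scaleFp of 0.5 with scale32=True gives
    # multiplier = 1 << 30 and shift = 31, because math.frexp(0.5) == (0.5, 0),
    # round(0.5 * (1 << 31)) == 1 << 30 and shift becomes -0 + 31 = 31.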


class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]
        shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rng, rank + rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rng, rank + rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rng, rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgGather(testGen, rng, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_shape = testGen.makeShape(rng, rank)
        values_shape = testGen.constrictBatchSize(values_shape)

        N = values_shape[0]
        W = testGen.makeDimension(rng)
        indices_shape = [N, W]

        shape_list = [values_shape, indices_shape]
        return shape_list

    @staticmethod
    def tgScatter(testGen, rng, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rng, rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = rng.randInt(W_min, W_max) if W_min < W_max else W_min

        input_shape = [N, W, C]

        shape_list = []
        shape_list.append(values_in_shape)
        shape_list.append([N, W])  # indices
        shape_list.append(input_shape)

        return shape_list

    @staticmethod
    def _get_broadcast_shapes(testGen, rng, num_shapes, rank, error_name=None):
        shape = testGen.makeShape(rng, rank)
        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = rng.randInt(0 if error_name is None else 1, num_shapes)
        fuzz_idx = rng.randInt(0, rank)

        for i in range(num_shapes):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If this is the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(rng, extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list
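    # As a rough illustration (actual shapes are random): with num_shapes=2, a
    # base shape of [3, 4, 5] and fuzz_idx=1, the non-broadcast input keeps
    # [3, 4, 5] while the chosen input becomes [3, 1, 5].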

    @staticmethod
    def tgBroadcastFuzz(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]
        num_shapes = pl + const
        return TosaTensorGen._get_broadcast_shapes(
            testGen, rng, num_shapes, rank, error_name
        )

    @staticmethod
    def tgMul(testGen, rng, op, rank, error_name=None):
        # Get broadcast shapes for the first 2 inputs as the 3rd is shift
        shape_list = TosaTensorGen._get_broadcast_shapes(
            testGen, rng, 2, rank, error_name
        )
        # Add a single dimension tensor for shift
        shape_list.append([1])
        return shape_list

    @staticmethod
    def tgConv2D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension(rng)

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension(rng)

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension(rng)

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension(rng) % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rng, rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))
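        # e.g. a height or width of 7 would be reduced to 4 here (the nearest
        # lower power of two)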

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rng, rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]
        shape = testGen.makeShape(rng, rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = rng.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(rng, shapeList, axis, error_name=None):
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList


class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each tensor in each test."""

    def __init__(self):
        pass

    class TVGInfo:
        """Enhanced tensor values information including data gen dict."""

        def __init__(self, tensorList, dataGenDict):
            self.tensorList = tensorList
            self.dataGenDict = dataGenDict

    # Default high value for random numbers
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
        DType.FP8E4M3: 448,
        DType.FP8E5M2: 57344,
    }

    # Default lowest normal values for random numbers
    TVG_FLOAT_LOW_VALUE = {
        DType.FP32: np.exp2(-126),
        DType.FP16: np.exp2(-14),
        DType.BF16: np.exp2(-126),
        DType.FP8E4M3: np.exp2(-9),
        DType.FP8E5M2: np.exp2(-16),
    }

    @staticmethod
    def _get_data_range(rng, dtype, highValueLookup, lowValueLookup=None):
        # Return a tuple of (low,high) data range values for the given data
        # type using a combination of per operator table limits, data limits
        # and user supplied ranges for FP numbers
        if dtype in highValueLookup:
            type_range = rng.dTypeRange(dtype, high_inclusive=True)
            high_val = highValueLookup[dtype]
            if lowValueLookup is not None and dtype in lowValueLookup:
                low_val = lowValueLookup[dtype]
            else:
                low_val = -high_val
            # Set the values to something that won't produce infinity whilst
            # respecting the default ranges if more/less than the low/high
            # values
            data_range = (
                max(low_val, type_range[0]),
                min(high_val, type_range[1]),
            )
            if data_range[0] > data_range[1]:
                # Invalid data range from low to high created due to user
                # constraints, so revert to using internal ranges as they
                # are known to work
                logger.info(
                    f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
                )
                data_range = (low_val, high_val)
            return data_range
        return None
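    # For example, with highValueLookup=TVG_FLOAT_HIGH_VALUE and an FP16 user
    # range of (-1000.0, 1000.0) the narrower user range is returned; if the
    # two ranges do not overlap, the lookup values are used instead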

    @staticmethod
    def tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
        if "p_count" in argsDict:
            # Override for operators like CONCAT
            pCount = argsDict["p_count"]
            cCount = argsDict["c_count"]
        assert pCount + cCount == len(
            shapeList
        ), "Placeholders & Constant tensors must match shapes list"

        tens_ser_list = []

        if (
            error_name is not None
            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
            or "data_gen" not in testGen.TOSA_OP_LIST[opName]
        ):
            # Fall back to internal data gen when dealing with unsupported types or ops
            data_range = argsDict["data_range"] if "data_range" in argsDict else None
            for idx, info in enumerate(zip(shapeList, dtypeList)):
                roundMode = False
                shape, dtype = info
                if "data_range_list" in argsDict:
                    data_range = argsDict["data_range_list"][idx]["range"]
                    roundMode = (
                        "round" in argsDict["data_range_list"][idx]
                        and argsDict["data_range_list"][idx]["round"] is True
                    )
                if data_range is not None and dtype not in (
                    DType.FP16,
                    DType.FP32,
                    DType.BF16,
                    DType.FP8E4M3,
                    DType.FP8E5M2,
                ):
                    # Change from inclusive to exclusive range
                    data_range = (data_range[0], data_range[1] + 1)

                # Ignore lazy data gen option and create data array using any range limits
                if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                    if dtype == DType.SHAPE:
                        arr = np.int64(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT8:
                        arr = np.int8(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT16:
                        arr = np.int16(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT32:
                        arr = np.int32(argsDict["fixed_data"][idx])
                    else:
                        assert False, "Unsupported fixed_data type"
                else:
                    arr = rng.randTensor(shape, dtype, data_range)
                if roundMode:
                    arr = np.round(arr)
                if idx < pCount:
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
        dg_type = argsDict["dg_type"]
        tens_data = {
            "version": "0.1",
            "tensors": {},
        }
        dg_tens_meta = tens_data["tensors"]
        for idx, shape in enumerate(shapeList):

            tens_meta = {}
            if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                tens_meta["generator"] = gtu.DataGenType(
                    gtu.DataGenType.FIXED_DATA
                ).name
            else:
                tens_meta["generator"] = gtu.DataGenType(dg_type).name

            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()

            if idx < pCount:
                tens_meta["input_type"] = "VARIABLE"
            else:
                tens_meta["input_type"] = "CONSTANT"

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                if (
                    tens_meta["generator"]
                    == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name
                ):
                    info["data"] = [int(i) for i in argsDict["fixed_data"][idx]]
                    tens_meta["fixed_data_info"] = info
                else:
                    info["rng_seed"] = rng.seed

                    data_range = None
                    if "data_range_list" in argsDict:
                        data_range = argsDict["data_range_list"][idx]["range"]
                        if "round" in argsDict["data_range_list"][idx]:
                            info["round"] = argsDict["data_range_list"][idx]["round"]
                    elif "data_range" in argsDict:
                        data_range = argsDict["data_range"]

                    if data_range is None:
                        data_range = rng.dTypeRange(dtypeList[idx], high_inclusive=True)
                    info["range"] = [str(v) for v in data_range]
                    tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = int(argsDict["ks"])
                if "acc_type" in argsDict:
                    # Convert type number into JSON name
                    info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
                        "json"
                    ]
                if "kernel" in argsDict:
                    info["kernel"] = [int(k) for k in argsDict["kernel"]]
                if "axis" in argsDict:
                    info["axis"] = int(argsDict["axis"])
                tens_meta["dot_product_info"] = info
            elif dg_type == gtu.DataGenType.FULL_RANGE:
                info = {}
                info["start_val"] = int(
                    rng.randInt(0, gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["fullset"])
                )
                tens_meta["full_range_info"] = info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"

            # Using the finished generate config meta data - generate the data if
            # needed and assign a tensor name from the serializer

            # Need to generate data when not lazy or for the bias tensor as we need
            # to work out if the bias data is non-zero for compliance
            if not testGen.args.lazy_data_gen or (
                idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
            ):
                # Give this tensor a temporary name until we get one from the serializer
                temp_name = f"placeholder_{idx}"
                dg_tens_meta[temp_name] = tens_meta
                # Create data now using the temporary name to access meta details
                data = testGen.dgl.get_tensor_data(temp_name, tens_data)
                if tens_meta["data_type"] == "SHAPE":
                    # Tensor type SHAPE and Numpy file type must be the same
                    data = np.int64(data)
                # Remove the item as we will give it the correct name later
                del dg_tens_meta[temp_name]

            if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
                # The KS value used by compliance verification is altered when the
                # bias data is non-zero
                if max(abs(data)) > 0.0:
                    argsDict["ksb"] = argsDict["ks"] + 1

            if testGen.args.lazy_data_gen:
                data = None

            if tens_meta["input_type"] == "VARIABLE":
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
            else:
                tens = testGen.ser.addConst(shape, dtypeList[idx], data)

            tens_ser_list.append(tens)
            # Add the meta data to the list using the serializer tensor name
            dg_tens_meta[tens.name] = tens_meta

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
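    # For reference, the meta-data built above might look roughly like the
    # following (illustrative values only - the tensor name comes from the
    # serializer and the JSON names come from gtu.DTYPE_ATTRIBUTES):
    #   {
    #       "version": "0.1",
    #       "tensors": {
    #           "<tensor name>": {
    #               "generator": "PSEUDO_RANDOM",
    #               "data_type": "FP32",
    #               "shape": [1, 8, 8, 3],
    #               "input_pos": 0,
    #               "op": "CONV2D",
    #               "input_type": "VARIABLE",
    #               "pseudo_random_info": {"rng_seed": 42, "range": ["-2.0", "2.0"]},
    #           }
    #       },
    #   }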

    @staticmethod
    def tvgNegate(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Integer test
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.NEGATE must have 1 placeholder, 0 consts"
            # Must create tensors with values within accumulator (int32) negatable
            # range
            max_val = (1 << 31) - 1
            min_val = -max_val
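            # (negating the full int32 minimum of -(1 << 31) would overflow the
            # int32 range, so it is excluded from the generated values)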
            arr = np.int32(
                rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
            )
            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )

    # Set the ADD/SUB data range to half the largest value to avoid infinities
    TVG_FLOAT_HIGH_VALUE_ADDSUB = {
        DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
        DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
        DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
    }
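    # (half-max plus half-max equals max, so adding two in-range values cannot
    # overflow to infinity)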
919
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100920 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100921 def tvgAddSub(
922 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
923 ):
Won Jeon74342e52024-01-09 00:34:40 +0000924 if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000925 # Make sure the integer operation does not cause value saturation - where
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100926 # the number wraps due to limited number of bits to store the answer
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000927 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100928 pCount, cCount = op["operands"]
929 assert (
930 pCount == 2 and cCount == 0
931 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000932 tens_ser_list = []
Won Jeon74342e52024-01-09 00:34:40 +0000933 add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
934 data_range = testGen.args.tensor_shape_range
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100935 a_arr = rng.randTensor(shapeList[0], dtypeList[0], data_range)
936 b_arr = rng.randTensor(shapeList[1], dtypeList[1], data_range)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100937 if add:
938 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
939 else:
940 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
941
942 # Work out the saturation limits
943 max_i32 = (1 << 31) - 1
944 min_i32 = -(1 << 31)
945 max_arr = np.full(shapeList[1], max_i32)
946 min_arr = np.full(shapeList[1], min_i32)
947
948 # Find how much values exceed the maximum/minimums
949 sat_max_arr = np.maximum(res_arr - max_arr, 0)
950 sat_min_arr = np.minimum(res_arr - min_arr, 0)
951
952 if not add:
953 # Swap saturation values and negate values as we need to perform opposite operations
954 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
955
956 # Create new array of unsaturated values by clipping values as needed
957 b_unsat_arr = b_arr
958 if (sat_max_arr != 0).any():
959 # Clip values that cause saturation
960 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
961 # Reduce axes in unsaturated tensor to match original tensor
962 for axis, dim in enumerate(b_arr.shape):
963 if dim != b_unsat_arr.shape[axis]:
964 assert (
965 dim == 1
966 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
967 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
968
969 if (sat_min_arr != 0).any():
970 # Clip values that cause saturation
971 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
972 # Reduce axes in unsaturated tensor to match original tensor
973 for axis, dim in enumerate(b_arr.shape):
974 if dim != b_unsat_arr.shape[axis]:
975 assert (
976 dim == 1
977 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
978 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
979
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000980 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100981 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
982 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000983 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100984 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
985 )
986
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000987 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100988 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000989 # ERROR_IF or floating point test
990 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100991 rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000992 )
993 if data_range:
994 argsDict["data_range"] = data_range
995
996 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100997 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100998 )
999
1000 @staticmethod
1001 def tvgCondIfWhileLoop(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001002 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001003 ):
1004 if dtypeList[0] in (
1005 DType.INT32,
1006 DType.INT16,
1007 DType.INT8,
1008 ):
1009 # Limit input tensors with cond_if_binary or while_loop to stop
1010 # saturation of add/sub ops with int32 and keep all logical shift
1011 # values between 0 to 31 for int16 or int8
Jeremy Johnson587cc842024-02-08 11:45:44 +00001012 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001013 pCount, cCount = op["operands"]
1014 pRemain = pCount
Jeremy Johnson587cc842024-02-08 11:45:44 +00001015 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001016 for idx, shape in enumerate(shapeList[:]):
1017 if dtypeList[0] == DType.INT32:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001018 arr = rng.randTensor(shapeList[idx], DType.INT16)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001019 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001020 arr = np.int32(rng.integers(low=0, high=32, size=shapeList[idx]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001021 if pRemain > 0:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001022 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001023 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
1024 )
1025 pRemain -= 1
1026 else:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001027 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001028 testGen.ser.addConst(shape, dtypeList[idx], arr)
1029 )
1030
Jeremy Johnson587cc842024-02-08 11:45:44 +00001031 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001032 else:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001033 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001034 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001035 )
1036
1037 @staticmethod
1038 def tvgArithmeticRightShift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001039 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001040 ):
Jeremy Johnson587cc842024-02-08 11:45:44 +00001041 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001042 pCount, cCount = op["operands"]
1043 # Force value of operand[1] to be within [0, num_bits]
1044 assert (
1045 pCount == 2 and cCount == 0
1046 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
1047
Jeremy Johnson587cc842024-02-08 11:45:44 +00001048 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001049 for idx, shape in enumerate(shapeList[:]):
1050 if idx == 1:
1051 if dtypeList[idx] == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001052 arr = np.int32(rng.integers(low=0, high=8, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001053 elif dtypeList[idx] == DType.INT16:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001054 arr = np.int32(rng.integers(low=0, high=16, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001055 elif dtypeList[idx] == DType.INT32:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001056 arr = np.int32(rng.integers(low=0, high=32, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001057 elif error_name == ErrorIf.WrongInputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001058 arr = np.int32(rng.integers(low=0, high=8, size=shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001059 else:
1060 raise Exception("OpArithmeticRightShift: invalid input dtype")
1061 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001062 arr = rng.randTensor(shape, dtypeList[idx])
Jeremy Johnson587cc842024-02-08 11:45:44 +00001063 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001064
Jeremy Johnson587cc842024-02-08 11:45:44 +00001065 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001066
1067 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001068 def tvgReshape(
1069 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1070 ):
Won Jeon64e4bfe2024-01-18 06:31:55 +00001071 dtypeList[1] = DType.SHAPE
1072 shapeList[1] = [len(argsDict["new_shape"])]
1073 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1074 argsDict["fixed_data"] = [None, argsDict["new_shape"]]
1075
1076 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001077 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Won Jeon64e4bfe2024-01-18 06:31:55 +00001078 )
1079
1080 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001081 def tvgRescale(
1082 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1083 ):
Tai Ly6e1e2bc2024-03-01 20:59:32 +00001084 scale32 = argsDict["scale"]
1085 multiplier_arr = argsDict["multiplier"]
1086 shift_arr = argsDict["shift"]
1087
1088 if scale32:
1089 dtypeList[1] = DType.INT32
1090 else:
1091 dtypeList[1] = DType.INT16
1092 shapeList[1] = [len(multiplier_arr)]
1093 dtypeList[2] = DType.INT8
1094 shapeList[2] = [len(shift_arr)]
1095 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1096 argsDict["fixed_data"] = [None, multiplier_arr, shift_arr]
1097
1098 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001099 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Tai Ly6e1e2bc2024-03-01 20:59:32 +00001100 )
1101
1102 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001103 def tvgPad(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Tai Lye095da72024-01-25 22:00:18 +00001104 # argsDict["pad"] is 2D array, need to flatten it to get list of values
1105 pad_values = argsDict["pad"].flatten()
1106 dtypeList[1] = DType.SHAPE
1107 shapeList[1] = [len(pad_values)]
1108 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1109 argsDict["fixed_data"] = [None, pad_values]
1110
1111 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001112 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Tai Lye095da72024-01-25 22:00:18 +00001113 )
1114
1115 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001116 def tvgSlice(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
TatWai Chongf15bad82024-01-31 21:33:27 -08001117 dtypeList[1] = DType.SHAPE
1118 shapeList[1] = [len(argsDict["start"])]
1119 dtypeList[2] = DType.SHAPE
1120 shapeList[2] = [len(argsDict["size"])]
1121 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1122 argsDict["fixed_data"] = [None, argsDict["start"], argsDict["size"]]
1123
1124 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001125 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
TatWai Chongf15bad82024-01-31 21:33:27 -08001126 )
1127
1128 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001129 def tvgTile(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Won Jeon64e4bfe2024-01-18 06:31:55 +00001130 dtypeList[1] = DType.SHAPE
1131 shapeList[1] = [len(argsDict["multiples"])]
1132 argsDict["fixed_data"] = [None, argsDict["multiples"]]
1133
1134 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001135 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Won Jeon64e4bfe2024-01-18 06:31:55 +00001136 )
1137
1138 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001139 def tvgSelect(
1140 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1141 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001142 # Set datatype of condition tensor to boolean
1143 dtypeList[0] = DType.BOOL
1144
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00001145 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001146 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001147 )
1148
1149 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001150 def tvgIntDiv(
1151 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1152 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001153 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001154 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001155 pCount, cCount = op["operands"]
1156 assert (
1157 pCount == 2 and cCount == 0
1158 ), "Op.INTDIV must have 2 placeholders, 0 consts"
1159
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001160 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001161
1162 # Two invalid cases for Op.INTDIV:
1163 # 1. divisor == 0
1164 # 2. dividend == -(1<<31) and divisor == -1
1165 while True:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001166 dividend_arr = rng.randTensor(shapeList[0], dtypeList[0])
1167 divisor_arr = rng.randTensor(shapeList[1], dtypeList[1])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001168
1169 if (divisor_arr == 0).any():
1170 continue
1171
1172 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
1173 continue
1174
1175 break
1176
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001177 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001178 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1179 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001180 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001181 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1182 )
1183
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001184 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001185 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001186 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001187 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001188 )
1189
Jeremy Johnson30476252023-11-20 16:15:30 +00001190 # Set the MUL data range to the square root of the largest value
1191 # to avoid infinities
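# e.g. for FP32 both operands are bounded by the square root of the largest
# value, so their product cannot exceed that largest value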
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001192 TVG_FLOAT_HIGH_VALUE_MUL = {
1193 DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1194 DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1195 DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1196 }
1197
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001198 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001199 def tvgMul(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001200 if error_name is not None or dtypeList[0] in (
1201 DType.FP16,
1202 DType.BF16,
1203 DType.FP32,
1204 ):
1205 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001206 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001207 rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001208 )
1209 if data_range:
1210 argsDict["data_range"] = data_range
1211
Jeremy Johnson0a042992024-02-28 13:20:05 +00001212 if dtypeList[0] != DType.SHAPE:
1213 # Need to supply shift tensor for MUL (not needed for MUL_SHAPE)
1214 dtypeList[2] = DType.INT8
1215 shapeList[2] = [1]
1216 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1217 argsDict["fixed_data"] = [None, None, [argsDict["shift"]]]
1218
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001219 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001220 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001221 )
1222 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001223 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001224 pCount, cCount = op["operands"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001225
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001226 tens_ser_list = []
1227
1228 # Make sure the multiply result is within the int32 range
Won Jeon74342e52024-01-09 00:34:40 +00001229 if dtypeList[0] == DType.SHAPE:
1230 shift = 0
1231 else:
1232 shift = argsDict["shift"]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001233 if dtypeList[0] == DType.INT8:
1234 num_bits = 8
1235 elif dtypeList[0] == DType.INT16:
1236 num_bits = 16
Won Jeon74342e52024-01-09 00:34:40 +00001237 elif dtypeList[0] in (DType.INT32, DType.SHAPE):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001238 num_bits = 32
1239 elif error_name == ErrorIf.WrongInputType:
1240 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001241 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001242 raise Exception(
1243 f"OpMul: invalid input dtype {gtu.DTYPE_ATTRIBUTES[dtypeList[0]]['str']}"
1244 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001245
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001246 for idx, shape in enumerate(shapeList[:]):
Won Jeon74342e52024-01-09 00:34:40 +00001247 if dtypeList[idx] == DType.SHAPE:
1248 low = testGen.args.tensor_shape_range[0]
1249 high = testGen.args.tensor_shape_range[1]
1250 else:
1251 low = -(2 ** (num_bits - 1))
1252 high = (2 ** (num_bits - 1)) - 1
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001253
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001254 a_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[0]))
1255 b_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[1]))
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001256
1257 i = 0
1258 while True:
1259
1260 a_arr_64 = a_arr.astype(np.int64)
1261 b_arr_64 = b_arr.astype(np.int64)
1262
1263 if shift > 0:
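# Add half of the shift divisor before the arithmetic right shift so the
# result is rounded to nearest, e.g. shift=2: (7 * 3 + 2) >> 2 == 5 (21 / 4 rounded)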
1264 rounding = 1 << (shift - 1)
1265 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001266 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001267 result_arr = a_arr_64 * b_arr_64
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001268
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001269 if (result_arr > -(2**31)).all() and (
1270 result_arr <= ((2**31) - 1)
1271 ).all():
1272 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001273
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001274 i = i + 1
1275 a_arr = a_arr // 2
1276 b_arr = b_arr // 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001277
Won Jeon74342e52024-01-09 00:34:40 +00001278 if dtypeList[0] == DType.SHAPE:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001279 # MUL_SHAPE with 2 inputs
Won Jeon74342e52024-01-09 00:34:40 +00001280 tens_ser_list.append(
1281 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
1282 )
1283 tens_ser_list.append(
1284 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
1285 )
1286 else:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001287 # MUL with 3 inputs (3rd is shift)
Won Jeon74342e52024-01-09 00:34:40 +00001288 tens_ser_list.append(
1289 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1290 )
1291 tens_ser_list.append(
1292 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1293 )
Jeremy Johnson0a042992024-02-28 13:20:05 +00001294 tens_ser_list.append(
1295 testGen.ser.addPlaceholder([1], DType.INT8, np.int8([shift]))
1296 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001297
1298 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001299
1300 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001301 def tvgConcat(
1302 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1303 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001304 count = len(shapeList) - testGen.args.num_const_inputs_concat
1305 if count < 1:
1306 count = 1
1307 if testGen.args.num_const_inputs_concat == 0:
1308 count = len(shapeList)
1309
Won Jeon74342e52024-01-09 00:34:40 +00001310 op = testGen.TOSA_OP_LIST[opName]
1311 if op["op"] == Op.CONCAT_SHAPE:
1312 # Set the axis to 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001313 shapeList = TosaTensorGen.tgConcatConstInput(rng, shapeList, 0, error_name)
Won Jeon74342e52024-01-09 00:34:40 +00001314 else:
1315 shapeList = TosaTensorGen.tgConcatConstInput(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001316 rng, shapeList, argsDict["axis"], error_name
Won Jeon74342e52024-01-09 00:34:40 +00001317 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001318
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001319 # Override default pCount/cCount for operator
1320 argsDict["p_count"] = count
1321 argsDict["c_count"] = len(shapeList) - count
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001322
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001323 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001324 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001325 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001326
1327 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001328 def tvgLogicalShift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001329 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001330 ):
1331 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001332 pCount, cCount = op["operands"]
1333 assert (
1334 pCount == 2 and cCount == 0
1335 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001336 values_arr = rng.randTensor(shapeList[0], dtypeList[0])
1337 shift_arr = np.int32(rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001338 tens_ser_list = []
1339 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001340 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1341 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001342 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001343 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1344 )
1345
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001346 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001347
1348 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001349 def tvgEqual(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnsona0150012023-11-15 15:52:06 +00001350 if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
1351 # Integer
1352 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001353 pCount, cCount = op["operands"]
1354 assert (
1355 pCount == 2 and cCount == 0
1356 ), "Op.EQUAL must have 2 placeholders, 0 consts"
Jeremy Johnsona0150012023-11-15 15:52:06 +00001357
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001358 a_arr = rng.randTensor(shapeList[0], dtypeList[0])
1359 b_arr = rng.randTensor(shapeList[1], dtypeList[1])
Jeremy Johnsona0150012023-11-15 15:52:06 +00001360
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001361 # Using random numbers means that it is very unlikely that there
1362 # will be any matching (equal) values, therefore force the number
1363 # of matching values to be twice the tensor rank
1364 for num in range(0, len(shapeList[0]) * 2):
1365 a_index = []
1366 b_index = []
1367 # Choose an index in each axis for the whole shape
1368 for axis in range(0, len(shapeList[0])):
1369 # Index can be up to the largest dimension in both shapes
1370 index = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001371 rng.integers(0, max(shapeList[0][axis], shapeList[1][axis]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001372 )
1373 # Reduce the index down to a shape's dim for broadcasting
1374 a_index.append(min(shapeList[0][axis] - 1, index))
1375 b_index.append(min(shapeList[1][axis] - 1, index))
1376
1377 a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
1378
Jeremy Johnsona0150012023-11-15 15:52:06 +00001379 tens_ser_list = []
1380 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001381 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1382 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001383 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001384 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1385 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001386 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001387 else:
Jeremy Johnsona0150012023-11-15 15:52:06 +00001388 # ERROR_IF or floating point test
1389 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001390 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001391 )
1392
1393 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001394 def tvgReduceSum(
1395 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1396 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001397 dtype = dtypeList[0]
1398 if dtype == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001399 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001400 pCount, cCount = op["operands"]
1401 assert (
1402 pCount == 1 and cCount == 0
1403 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
1404 # Limit values so that the sum cannot exceed the range of an int32 during
1405 # summation of any axis
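# e.g. a largest dimension of 64 bounds each element by 2**31 // 64 = 2**25,
# so no sum along an axis can overflow int32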
1406 range_val = int((1 << 31) / max(shapeList[0]))
1407 values_arr = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001408 rng.integers(low=-range_val, high=range_val, size=shapeList[0])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001409 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001410 tens_ser_list = []
1411 tens_ser_list.append(
Jeremy Johnson30476252023-11-20 16:15:30 +00001412 testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001413 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001414 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001415 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001416 # ERROR_IF or dot product floating point test
Jeremy Johnson30476252023-11-20 16:15:30 +00001417 if (
1418 error_name is None
1419 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
1420 ):
1421 # Limit ranges for (non error & non compliance) tests by using
1422 # values that can be summed on any axis to not hit infinity
1423 highval_lookup = {
1424 dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
1425 / max(shapeList[0])
1426 }
1427 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001428 rng, dtype, highval_lookup
Jeremy Johnson30476252023-11-20 16:15:30 +00001429 )
1430 assert data_range is not None
1431 argsDict["data_range"] = data_range
1432
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001433 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001434 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001435 )
1436
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001437 @staticmethod
1438 def tvgReduceProduct(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001439 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001440 ):
1441 dtype = dtypeList[0]
1442 if error_name is None:
1443 # Limit ranges for (non error) tests by using
1444 # values that can be multiplied on any axis to not hit infinity
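# i.e. bound each value by HIGH ** (1 / N), where N is the largest axis, so a
# product of up to N such values stays at or below HIGH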
1445 highval_lookup = {
1446 dtype: math.pow(
1447 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
1448 1 / max(shapeList[0]),
1449 )
1450 }
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001451 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001452 assert data_range is not None
1453 argsDict["data_range"] = data_range
1454
1455 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001456 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001457 )
1458
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001459 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001460 def tvgResize(
1461 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1462 ):
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001463 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001464 rng,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001465 dtypeList[0],
1466 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1467 )
1468 if data_range:
1469 argsDict["data_range"] = data_range
1470 # Needed for compliance
1471 argsDict["max_abs_value"] = data_range[1]
1472
Tai Lyc5c2a7e2024-02-22 23:26:28 +00001473 scale_values = argsDict["scale"]
1474 offset_values = argsDict["offset"]
1475 border_values = argsDict["border"]
1476 dtypeList[1] = DType.SHAPE
1477 dtypeList[2] = DType.SHAPE
1478 dtypeList[3] = DType.SHAPE
1479 shapeList[1] = [len(scale_values)]
1480 shapeList[2] = [len(offset_values)]
1481 shapeList[3] = [len(border_values)]
1482 argsDict["fixed_data"] = [None, scale_values, offset_values, border_values]
1483
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001484 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001485 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001486 )
1487
Jeremy Johnson30476252023-11-20 16:15:30 +00001488 # Set the POW exponent high data range
1489 TVG_FLOAT_HIGH_VALUE_POW_EXP = {
1490 DType.FP32: 10.0,
1491 DType.FP16: 10.0,
1492 DType.BF16: 10.0,
1493 }
1494 # Highest POW base value (within a safe margin of error) that can be raised
1495 # to a positive exponent without becoming Infinity
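# Derived as floor(TVG_FLOAT_HIGH_VALUE ** (1 / TVG_FLOAT_HIGH_VALUE_POW_EXP)),
# so that base ** exponent stays below the largest finite value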
1496 TVG_FLOAT_HIGH_VALUE_POW_BASE = {
1497 DType.FP32: math.floor(
1498 math.pow(
1499 TVG_FLOAT_HIGH_VALUE[DType.FP32],
1500 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1501 )
1502 ),
1503 DType.FP16: math.floor(
1504 math.pow(
1505 TVG_FLOAT_HIGH_VALUE[DType.FP16],
1506 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1507 )
1508 ),
1509 DType.BF16: math.floor(
1510 math.pow(
1511 TVG_FLOAT_HIGH_VALUE[DType.BF16],
1512 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1513 )
1514 ),
1515 }
1516 # Lowest POW base value (within a safe margin of error) that can be raised
1517 # to a negative exponent without becoming Infinity
1518 TVG_FLOAT_LOW_VALUE_POW_BASE = {
1519 DType.FP32: math.ceil(
1520 math.pow(
1521 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
1522 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1523 )
1524 * 1000
1525 )
1526 / 1000,
1527 DType.FP16: math.ceil(
1528 math.pow(
1529 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
1530 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1531 )
1532 * 1000
1533 )
1534 / 1000,
1535 DType.BF16: math.ceil(
1536 math.pow(
1537 1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
1538 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1539 )
1540 * 1000
1541 )
1542 / 1000,
1543 }
1544
1545 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001546 def tvgPow(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001547 if error_name is not None:
1548 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001549 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001550 )
1551 dtype = dtypeList[0]
1552 # Different ranges for POW
1553 test_set = argsDict["s"]
1554 if test_set == 0:
1555 # Positive base with fractional exponent
1556 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001557 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001558 dtype,
1559 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1560 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1561 )
1562 exp_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001563 rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
Jeremy Johnson30476252023-11-20 16:15:30 +00001564 )
1565 exp_round = False
1566 else:
1567 # Integer exponent
1568 exp_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001569 rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
Jeremy Johnson30476252023-11-20 16:15:30 +00001570 )
1571 exp_round = True
1572 if test_set == 1:
1573 # Positive base
1574 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001575 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001576 dtype,
1577 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1578 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1579 )
1580 else:
1581 assert test_set == 2
1582 # Negative base
1583 # Supply new look up tables with negative values
1584 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001585 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001586 dtype,
1587 {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
1588 {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
1589 )
1590
1591 data_range_list = (
1592 {
1593 "range": base_range,
1594 },
1595 {
1596 "range": exp_range,
1597 "round": exp_round,
1598 },
1599 )
1600 argsDict["data_range_list"] = data_range_list
1601 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001602 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001603 )
1604
1605 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001606 def tvgLogRsqrt(
1607 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1608 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001609 # LOG & RSQRT data range from lowest expressible positive number to
1610 # largest to avoid NaNs
1611 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001612 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001613 dtypeList[0],
1614 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1615 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1616 )
1617 if data_range:
1618 argsDict["data_range"] = data_range
1619
1620 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001621 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001622 )
1623
1624 # Set the EXP data range to the logs of the largest and smallest values
1625 # to avoid infinities or making the result zero
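# i.e. exp(x) stays within [TVG_FLOAT_LOW_VALUE, TVG_FLOAT_HIGH_VALUE] when x
# lies between the logs of those bounds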
1626 TVG_FLOAT_HIGH_VALUE_EXP = {
1627 DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1628 DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1629 DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1630 }
1631 TVG_FLOAT_LOW_VALUE_EXP = {
1632 DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
1633 DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
1634 DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
1635 }
1636
1637 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001638 def tvgExp(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001639 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001640 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001641 dtypeList[0],
1642 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1643 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1644 )
1645 if data_range:
1646 argsDict["data_range"] = data_range
1647
1648 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001649 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001650 )
1651
1652 @staticmethod
1653 def tvgFullyConnected(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001654 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson30476252023-11-20 16:15:30 +00001655 ):
1656 dtype = dtypeList[0]
1657 if (
1658 error_name is None
1659 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
Jeremy Johnson718f3472023-11-30 14:18:19 +00001660 and dtype in (DType.BF16,)
Jeremy Johnson30476252023-11-20 16:15:30 +00001661 ):
Jeremy Johnson718f3472023-11-30 14:18:19 +00001662 # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
Jeremy Johnson30476252023-11-20 16:15:30 +00001663 # Limit ranges for (non error & non compliance) FP tests by using
1664 # values that can be multiplied on any axis to not hit infinity/NaN
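# IC products are accumulated per output element, so bounding the inputs by
# HIGH ** (1 / IC) keeps the accumulation comfortably below the largest finite value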
1665 IC = shapeList[0][1]
1666 highval_lookup = {
1667 dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
1668 }
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001669 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnson30476252023-11-20 16:15:30 +00001670 assert data_range is not None
1671 argsDict["data_range"] = data_range
1672
1673 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001674 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001675 )
1676
Jeremy Johnson708da822023-11-15 16:25:45 +00001677 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001678 def tvgCast(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson708da822023-11-15 16:25:45 +00001679 in_dtype = dtypeList[0]
1680 out_dtype = argsDict["out_type"]
1681 # Create a lookup to limit the input tensor to the output type maximum to avoid
1682 # FP infinities and saturation of integers
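# e.g. casting FP32 to FP16 limits the input data to the FP16 maximum (65504)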
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001683 out_range = rng.dTypeRange(out_dtype, high_inclusive=True)
Jeremy Johnson708da822023-11-15 16:25:45 +00001684 highval_lookup = {in_dtype: out_range[1]}
1685 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001686 rng,
Jeremy Johnson708da822023-11-15 16:25:45 +00001687 in_dtype,
1688 highval_lookup,
1689 )
1690
1691 assert data_range is not None
1692 argsDict["data_range"] = data_range
1693
1694 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001695 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson708da822023-11-15 16:25:45 +00001696 )
1697
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001698 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001699 def tvgGather(
1700 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1701 ):
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001702 K = shapeList[0][1]
1703
1704 # Fix the type of the indices tensor
1705 dtypeList[1] = DType.INT32
1706
1707 dtype = dtypeList[0]
1708 if not gtu.dtypeIsSupportedByCompliance(dtype):
1709 # Test unsupported by data generator
1710 op = testGen.TOSA_OP_LIST[opName]
1711 pCount, cCount = op["operands"]
1712 assert (
1713 pCount == 2 and cCount == 0
1714 ), "Op.GATHER must have 2 placeholders, 0 consts"
1715
1716 tens_ser_list = []
1717 for idx, shape in enumerate(shapeList):
1718 dtype = dtypeList[idx]
1719 if idx != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001720 arr = rng.randTensor(shape, dtype)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001721 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1722 else:
1723 # Limit data range of the indices tensor to 0 up to K (exclusive)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001724 arr = rng.randTensor(shape, dtype, (0, K))
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001725 # To match old functionality - create indices as CONST
1726 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1727
1728 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1729
1730 else:
1731 # ERROR_IF or floating point test
1732 # Use inclusive values up to index K - 1 for the indices tensor
1733 data_range_list = (
1734 {"range": None},
1735 {"range": (0, K - 1)},
1736 )
1737 argsDict["data_range_list"] = data_range_list
1738
1739 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001740 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001741 )
1742
1743 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001744 def tvgScatter(
1745 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1746 ):
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001747 K = shapeList[0][1]
1748 W = shapeList[2][1]
1749
1750 # Work out an indices tensor here with data that doesn't exceed the
1751 # dimension K of the values_in tensor and that does NOT repeat any K
1752 # location, as required by the spec:
1753 # "It is not permitted to repeat the same output index within a single
1754 # SCATTER operation and so each output index occurs at most once."
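# e.g. with K=5 and W=3 each batch row gets 3 unique indices drawn from a shuffled range(5)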
1755 assert K >= W, "Op.SCATTER W must be smaller than or equal to K"
1756
1757 # Fix the type of the indices tensor
1758 dtypeList[1] = DType.INT32
1759
1760 dtype = dtypeList[0]
1761 if not gtu.dtypeIsSupportedByCompliance(dtype):
1762 # Test unsupported by data generator
1763 op = testGen.TOSA_OP_LIST[opName]
1764 pCount, cCount = op["operands"]
1765 assert (
1766 pCount == 3 and cCount == 0
1767 ), "Op.SCATTER must have 3 placeholders, 0 consts"
1768
1769 tens_ser_list = []
1770 for idx, shape in enumerate(shapeList):
1771 dtype = dtypeList[idx]
1772 if idx != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001773 arr = rng.randTensor(shape, dtype)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001774 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1775 else:
1776 # Create the indices array
1777 assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
1778 arr = []
1779 for n in range(shape[0]):
1780 # Get a shuffled list of output indices (0 to K-1) and
1781 # limit length to W
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001782 arr.append(rng.permutation(K)[:W])
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001783 indices_arr = np.array(arr, dtype=np.int32) # (N, W)
1784 # To match old functionality - create indices as CONST
1785 tens_ser_list.append(
1786 testGen.ser.addConst(shape, dtype, indices_arr)
1787 )
1788
1789 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1790
1791 else:
1792 # ERROR_IF or floating point test
1793 # Use inclusive values up to index K - 1 for the indices tensor
1794 data_range_list = (
1795 {"range": None},
1796 {"range": (0, K - 1)},
1797 {"range": None},
1798 )
1799 argsDict["data_range_list"] = data_range_list
1800
1801 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001802 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001803 )
1804
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001805
1806class TosaArgGen:
1807 """Argument generators create exhaustive or random lists of attributes for
1808 operators that take attributes or other parameters.
1809
1810 The return value is a list of (descriptive_name, [arglist]) tuples where
1811 the descriptive_name is appended to the test name and the arglist is expanded
1812 as arguments to the operator build function.
1813 """
1814
1815 def __init__(self):
1816 pass
1817
1818 @staticmethod
evacha019c96eef2024-02-07 11:21:55 +00001819 def _add_data_generators(testGen, opName, shapeList, dtype, arg_list, error_name):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001820 """Add extra tests for each type of data generator for this op."""
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001821 if (
1822 error_name is None
1823 and "data_gen" in testGen.TOSA_OP_LIST[opName]
1824 and gtu.dtypeIsSupportedByCompliance(dtype)
1825 ):
Tai Ly60dc48c2024-03-08 22:19:41 +00001826 if gtu.dtypeIsFloat(dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001827 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
1828 else:
1829 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
1830 else:
1831 # Error test or No data generator types listed - assume random
1832 dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
1833
1834 # Expand arg list with other data generator types
1835 new_arg_list = []
1836 for dg_type in dataGenTypesList:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001837 for arg_str, args_dict in arg_list:
evacha019c96eef2024-02-07 11:21:55 +00001838
1839 if dg_type == gtu.DataGenType.FULL_RANGE:
1840 tensor_size = gtu.product(shapeList[0])
1841 if tensor_size >= gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]:
1842 # Large enough tensor data size for full range, add a single test
1843 num_test_sets = 0
1844 else:
1845 # Not enough data size for full range of values, revert to random numbers
1846 dg_type = gtu.DataGenType.PSEUDO_RANDOM
1847
Jeremy Johnson1271c442023-09-05 11:39:26 +01001848 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
Jeremy Johnson30476252023-11-20 16:15:30 +00001849 if error_name is None:
1850 num_test_sets = (
1851 args_dict["num_test_sets"]
1852 if "num_test_sets" in args_dict
1853 else 0
1854 )
1855 else:
evacha019c96eef2024-02-07 11:21:55 +00001856 # Add single test for pseudo random
Jeremy Johnson30476252023-11-20 16:15:30 +00001857 num_test_sets = 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01001858
1859 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
1860 # Extra tests for each dot product test set
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001861 dot_products = args_dict["dot_products"]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001862 if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001863 shape_info = (
1864 " ({})".format(testGen.shapeStr(args_dict["shape"]))
1865 if "shape" in args_dict
1866 else ""
1867 )
Jeremy Johnsonaf090182024-02-13 18:25:39 +00001868 logger.info(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001869 f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
Jeremy Johnson1271c442023-09-05 11:39:26 +01001870 )
1871 continue
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001872 # KS and acc_type is required by all dot product generators
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001873 assert "ks" in args_dict
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001874 assert "acc_type" in args_dict
Jeremy Johnson1271c442023-09-05 11:39:26 +01001875
Jeremy Johnson30476252023-11-20 16:15:30 +00001876 num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS
1877
1878 if num_test_sets > 0:
1879 for s in range(0, num_test_sets):
evacha019c96eef2024-02-07 11:21:55 +00001880 set_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
1881 set_args_dict = args_dict.copy()
1882 set_args_dict["s"] = s
1883 set_args_dict["dg_type"] = dg_type
1884 new_arg_list.append((set_arg_str, set_args_dict))
Jeremy Johnson30476252023-11-20 16:15:30 +00001885 else:
1886 # Default is a single test
evacha019c96eef2024-02-07 11:21:55 +00001887 new_args_dict = args_dict.copy()
1888 new_args_dict["dg_type"] = dg_type
1889 new_arg_list.append((arg_str, new_args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001890
1891 return new_arg_list
1892
1893 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001894 def agNone(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001895 """A trivial argument generator for operators that don't take any
1896 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001897 arg_list = TosaArgGen._add_data_generators(
1898 testGen,
1899 opName,
evacha019c96eef2024-02-07 11:21:55 +00001900 shapeList,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001901 dtype,
1902 [("", {})],
1903 error_name,
1904 )
1905 # Return list of tuples: (arg_str, args_dict)
1906 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001907
1908 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001909 def agPow(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001910 """Pow operator needs different test sets to cover random numbers
1911 without creating NaNs or Infs"""
1912 arg_list = TosaArgGen._add_data_generators(
1913 testGen,
1914 opName,
evacha019c96eef2024-02-07 11:21:55 +00001915 shapeList,
Jeremy Johnson30476252023-11-20 16:15:30 +00001916 dtype,
1917 [("", {"num_test_sets": 3})],
1918 error_name,
1919 )
1920 # Return list of tuples: (arg_str, args_dict)
1921 return arg_list
1922
1923 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001924 def agAxis(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001925 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001926 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001927 shape = shapeList[0]
1928
1929 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001930 # Set too small axis
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001931 axes = [rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001932 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001933 # Set too large axis
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001934 axes = [rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001935 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001936 # Create tests for each dimension
1937 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001938
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001939 opid = testGen.TOSA_OP_LIST[opName]["op"]
1940
1941 for a in axes:
1942 args_dict = {"axis": int(a)}
1943 if opid == Op.REDUCE_SUM:
Jeremy Johnsone52c0a32024-03-11 09:58:24 +00001944 output_shape = shape.copy()
1945 if error_name is None:
1946 # It only matters that we calculate the dot_products correctly
1947 # for non error_if tests as they should never be run
1948 output_shape[a] = 1
1949 args_dict["dot_products"] = gtu.product(output_shape)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001950 args_dict["shape"] = shape
1951 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1952 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1953
1954 arg_list.append(("axis{}".format(a), args_dict))
1955
1956 arg_list = TosaArgGen._add_data_generators(
1957 testGen,
1958 opName,
evacha019c96eef2024-02-07 11:21:55 +00001959 shapeList,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001960 dtype,
1961 arg_list,
1962 error_name,
1963 )
1964 # Return list of tuples: (arg_str, args_dict)
1965 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001966
1967 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001968 def _calculate_sparsity(num_tests, sparsity_factor):
1969 sparsity = num_tests // sparsity_factor + 1
1970 # If there are only a small number of tests, just select them all
1971 if sparsity < 13:
1972 sparsity = 1
1973 # To get a variety of parameter combinations sparsity should not be a
1974 # multiple of 2, 3 or 5
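# e.g. 10000 combinations with sparsity_factor 120: 10000 // 120 + 1 = 84,
# bumped to 89 (not a multiple of 2, 3 or 5), so roughly 1 in 89 combinations is kept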
1975 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1976 sparsity += 1
1977 return sparsity
1978
1979 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001980 def agConv(testGen, rng, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson0c716862023-04-13 17:18:19 +01001981 # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001982 arg_list = []
1983
Jeremy Johnson0c716862023-04-13 17:18:19 +01001984 if testGen.args.level8k and error_name is not None:
1985 # Don't produce negative large tests
1986 return arg_list
1987
1988 # Shape: Batches, (Depth), Height, Width, Channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001989 ifm_shape = shapeList[0]
Jeremy Johnson0c716862023-04-13 17:18:19 +01001990 # Shape: (OFM channels), (KD), KH, KW, IFM channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001991 filter_shape = shapeList[1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001992
Tai Lyf36f2562024-03-14 16:21:29 +00001993 accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
1994
1995 if error_name == ErrorIf.WrongAccumulatorType:
1996 accum_dtypes = (
1997 [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
1998 )
James Ward8b390432022-08-12 20:48:56 +01001999
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002000 # Op type checks
Jeremy Johnson0c716862023-04-13 17:18:19 +01002001 conv3d = opName.startswith("conv3d")
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002002 depthwise = opName.startswith("depthwise")
2003
2004 # Check the rank
Jeremy Johnson0c716862023-04-13 17:18:19 +01002005 rank = 5 if conv3d else 4
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002006 if error_name != ErrorIf.WrongRank:
2007 assert len(ifm_shape) == rank
2008 assert len(filter_shape) == rank
2009
Jeremy Johnson0c716862023-04-13 17:18:19 +01002010 # kernel rank omits channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002011 k_rank = rank - 2
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002012 k_pos = 0 if depthwise else 1
Jeremy Johnson0c716862023-04-13 17:18:19 +01002013 k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002014 # compliance size - KS
2015 k_size = gtu.product(k_shape)
2016 if not depthwise:
2017 k_size *= ifm_shape[-1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002018
Jeremy Johnson0c716862023-04-13 17:18:19 +01002019 if not testGen.args.level8k:
2020 # Generate comprehensive argument lists
2021 # - except for named errors, which use specific invalid value(s)
2022 if error_name == ErrorIf.PadSmallerZero:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002023 p_vals = [rng.choice(range(-5, 0))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002024 else:
2025 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
2026 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
2027 if error_name == ErrorIf.StrideSmallerOne:
2028 # Can't use stride=0, as it is used as a divisor when deriving the output shape
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002029 s_vals = [rng.choice(range(-5, 0))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002030 else:
2031 # Stride must be greater than 1 to force non-integer error
2032 startStride = (
2033 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002034 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002035 s_vals = [
2036 x for x in range(startStride, testGen.args.max_conv_stride + 1)
2037 ]
2038 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
2039 if error_name == ErrorIf.DilationSmallerOne:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002040 d_vals = [rng.choice(range(-5, 1))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002041 else:
2042 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
2043 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002044
Jeremy Johnson0c716862023-04-13 17:18:19 +01002045 if not error_name and testGen.args.oversize:
2046 # add some oversize argument values
2047 if max(ifm_shape) < 64:
2048 bigPadding = 9
2049 paddings.update(
2050 {
2051 x
2052 for x in itertools.product(
2053 *([[0, bigPadding]] * (k_rank * 2))
2054 )
2055 }
2056 )
2057 bigStride = 8
2058 strides.update(
2059 {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
2060 )
2061 bigDilation = 7
2062 dilations.update(
2063 {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
2064 )
2065 max_dim_size = None
2066
2067 # There are too many parameter combinations, so generate them sparsely,
2068 # very sparse for negative tests
2069 sparsity_factor = 2 if error_name else 120
2070 sparsity = TosaArgGen._calculate_sparsity(
2071 len(paddings) * len(strides) * len(dilations), sparsity_factor
2072 )
2073 else:
2074 # Only test 8k levels boundaries
2075 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2076 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2077 bigPadding = bigKernel
2078
2079 dilation_shape = [1] * k_rank
2080 pad_shape = [0] * k_rank * 2
2081 if conv3d:
2082 # Small stride apart from for big kernel (see below) to keep
2083 # tensor size/calculation small
2084 stride_shape = [1] * k_rank
2085 for idx in range(k_rank):
2086 pad_offset = idx * 2
2087 if k_shape[idx] == bigKernel:
2088 # Padding shape needs to account for tensor shape
2089 pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
2090 pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
2091 # Big stride to reduce output size
2092 stride_shape[idx] = bigKernel
2093 else:
2094 # Account for kernel size
2095 pad_shape[pad_offset] = k_shape[idx] - 1
2096 else:
2097 # Always have a large stride with extra padding and dilation to keep
2098 # tensor calculation reasonable
2099 stride_shape = [bigKernel] * k_rank
2100 for idx in range(k_rank):
2101 # Dilation shape must account for kernel size
2102 dilation_shape[idx] = bigKernel // k_shape[idx]
2103 # Padding shape needs to accommodate tensor/kernel & dilation
2104 pad_offset = idx * 2
2105 pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
2106 pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
2107
2108 strides = {tuple(stride_shape)}
2109 dilations = {tuple(dilation_shape)}
2110 paddings = {tuple(pad_shape)}
2111 # Create a limit for the output dimensions size
2112 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2113
2114 # Currently allow all combinations that are reasonable size
2115 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002116
2117 n = 0
Tai Lyf36f2562024-03-14 16:21:29 +00002118 for a in accum_dtypes:
2119 for s in sorted(list(strides)):
2120 for p in sorted(list(paddings)):
2121 for d in sorted(list(dilations)):
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002122 if (
Tai Lyf36f2562024-03-14 16:21:29 +00002123 n % sparsity == 0
2124 # the padded shape must exceed the dilation * kernel to get a positive
2125 # sized output shape
2126 and (ifm_shape[1] - 1 + p[0] + p[1])
2127 > d[0] * (k_shape[0] - 1)
2128 and (ifm_shape[2] - 1 + p[2] + p[3])
2129 > d[1] * (k_shape[1] - 1)
2130 and (
2131 k_rank < 3
2132 or (
2133 (ifm_shape[3] - 1 + p[4] + p[5])
2134 > d[2] * (k_shape[2] - 1)
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002135 )
2136 )
Tai Lyf36f2562024-03-14 16:21:29 +00002137 ):
2138 remainders = []
2139 outputs = []
2140 for index in range(k_rank):
2141 pad_offset = index * 2
2142 partial = (
2143 ifm_shape[index + 1]
2144 - 1
2145 + p[pad_offset]
2146 + p[pad_offset + 1]
2147 - (k_shape[index] - 1) * d[index]
2148 )
2149 remainders.append(partial % s[index])
2150 outputs.append((partial // s[index]) + 1)
2151
2152 if (
2153 # the parameters must produce integer exact output
2154 error_name != ErrorIf.ConvOutputShapeNonInteger
2155 and max(remainders) == 0
2156 ) or (
2157 error_name == ErrorIf.ConvOutputShapeNonInteger
2158 and max(remainders) > 0
2159 ):
2160 if (
2161 max_dim_size is not None
2162 and max(outputs) >= max_dim_size
2163 ):
2164 # Test will consume too much memory - skip it
2165 continue
2166
2167 # Compliance - number of dot product calculations
2168 if depthwise:
2169 # N*OH*OW*C*M
2170 dots = gtu.product(
2171 (ifm_shape[0], *outputs, *filter_shape[2:])
2172 )
2173 else:
2174 # N*OH*OW*OC or N*OD*OH*OW*OC
2175 dots = gtu.product(
2176 (ifm_shape[0], *outputs, filter_shape[0])
2177 )
2178 args_dict = {
2179 "acc_type": a,
2180 "stride": s,
2181 "pad": p,
2182 "dilation": d,
2183 "kernel": k_shape,
2184 "ks": k_size,
2185 "dot_products": dots,
2186 "shape": ifm_shape,
2187 }
2188
2189 # Supporting values larger than 9 needs a different delimiter
2190 delim = "" if max(s + p + d) <= 9 else "x"
2191 arg_list.append(
2192 (
2193 "acc{}_st{}_pad{}_dilat{}".format(
2194 testGen.typeStr(a),
2195 delim.join([str(x) for x in s]),
2196 delim.join([str(x) for x in p]),
2197 delim.join([str(x) for x in d]),
2198 ),
2199 args_dict,
2200 )
2201 )
2202 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002203
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002204 arg_list = TosaArgGen._add_data_generators(
2205 testGen,
2206 opName,
evacha019c96eef2024-02-07 11:21:55 +00002207 shapeList,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002208 dtypes[0],
2209 arg_list,
2210 error_name,
2211 )
2212 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002213 return arg_list
2214
2215 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002216 def agFullyConnected(testGen, rng, opName, shapeList, dtypes, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002217
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002218 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002219 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01002220
2221 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002222 accum_dtype = gtu.get_wrong_output_type(opName, rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01002223 elif error_name == ErrorIf.WrongInputType:
2224 # Pick some potentially correct output dtype if input type is incorrect
2225 accum_dtype = DType.INT32
2226 else:
Tai Lyf36f2562024-03-14 16:21:29 +00002227 accum_dtype = dtypes[-1] # use output dtype as accum_dtype
James Ward8b390432022-08-12 20:48:56 +01002228
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002229 # Set up compliance info
2230 args_dict = {
2231 "acc_type": accum_dtype,
2232 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
2233 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
2234 "shape": shapeList[0],
2235 }
2236
2237 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
2238
2239 arg_list = TosaArgGen._add_data_generators(
2240 testGen,
2241 opName,
evacha019c96eef2024-02-07 11:21:55 +00002242 shapeList,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002243 input_dtype,
2244 arg_list,
2245 error_name,
2246 )
2247 # Return list of tuples: (arg_str, args_dict)
2248 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002249
2250 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002251 def agMatMul(testGen, rng, opName, shapeList, dtype, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002252 # Get valid accumulate type(s)
2253 if dtype == DType.INT8:
2254 accum_dtypes = [DType.INT32]
2255 elif dtype == DType.INT16:
2256 accum_dtypes = [DType.INT48]
2257 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002258 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002259 elif dtype == DType.BF16:
2260 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002261 elif dtype == DType.FP32:
2262 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002263 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2264 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002265 elif error_name is None:
2266 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
2267
2268 if error_name == ErrorIf.WrongOutputType:
2269 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002270 accum_dtypes = [gtu.get_wrong_output_type(opName, rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01002271 elif error_name == ErrorIf.WrongInputType:
2272 # Pick some potentially correct output dtype if input type is incorrect
2273 accum_dtypes = [DType.INT32]
2274
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002275 # Set up compliance info
2276 args_dict = {
2277 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
2278 # Set dot_products = N*H*W
2279 "dot_products": gtu.product(
2280 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
2281 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002282 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002283 }
2284
2285 # Create arg tuple of string and dict
2286 arg_list = []
2287 for a in accum_dtypes:
2288 d = args_dict.copy()
2289 d["acc_type"] = a
2290 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01002291
2292 arg_list = TosaArgGen._add_data_generators(
2293 testGen,
2294 opName,
evacha019c96eef2024-02-07 11:21:55 +00002295 shapeList,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002296 dtype,
2297 arg_list,
2298 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002299 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002300 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01002301 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002302
2303 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002304 def agTransposeConv2D(testGen, rng, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002305 arg_list = []
2306
Jeremy Johnson0c716862023-04-13 17:18:19 +01002307 if testGen.args.level8k and error_name is not None:
2308 # Don't produce negative large tests
2309 return arg_list
2310
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002311 ifm_shape = shapeList[0]
2312 filter_shape = shapeList[1]
2313
Tai Lyf36f2562024-03-14 16:21:29 +00002314 accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
2315
2316 if error_name == ErrorIf.WrongAccumulatorType:
2317 accum_dtypes = (
2318 [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
2319 )
James Ward8b390432022-08-12 20:48:56 +01002320
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002321 # Must be rank 4
2322 if error_name != ErrorIf.WrongRank:
2323 assert len(ifm_shape) == 4
2324 assert len(filter_shape) == 4
2325
Jeremy Johnson0c716862023-04-13 17:18:19 +01002326 k_shape = tuple(filter_shape[1:3])
Jeremy Johnson95a67102024-01-10 14:16:39 +00002327 # compliance size - KS
2328 k_size = gtu.product((*k_shape, ifm_shape[3]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002329
Jeremy Johnson0c716862023-04-13 17:18:19 +01002330 if not testGen.args.level8k:
2331 # Generate comprehensive argument lists
2332 # - except for named errors, which use specific invalid value(s)
2333 smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
2334 if error_name == ErrorIf.PadLargerEqualKernel:
2335 max_filter_size = -max(k_shape[0], k_shape[1])
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002336 p_vals = [rng.choice(range(max_filter_size - 10, max_filter_size))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002337 else:
2338 p_vals = [
2339 x
2340 for x in range(
2341 smallest_padding_size, testGen.args.max_conv_padding + 1
2342 )
2343 ]
2344 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2345 if error_name == ErrorIf.StrideSmallerOne:
2346 # Can't use stride=0, as it is used as a divisor when deriving the output shape
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002347 s_vals = [rng.choice(range(-5, 0))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002348 else:
2349 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
2350 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002351
Jeremy Johnson0c716862023-04-13 17:18:19 +01002352 if not error_name and testGen.args.oversize:
2353 # add some oversize argument values
2354 if max(ifm_shape) < 64:
2355 bigPadding = 9
2356 paddings.update(
2357 {
2358 x
2359 for x in itertools.product(
2360 *([[smallest_padding_size, bigPadding]] * 4)
2361 )
2362 }
2363 )
2364 bigStride = 8
2365 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
2366
2367 # There are too many parameter combinations, so generate them sparsely,
2368 # very sparse for negative tests
2369 sparsity_factor = 2 if error_name else 10
2370 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
2371 # If there are only a small number of tests, just select them all
2372 if sparsity < 13:
2373 sparsity = 1
2374 # To get a variety of parameter combinations sparsity should not be a
2375 # multiple of 2, 3 or 5
2376 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
2377 sparsity += 1
2378 else:
2379 # Only test 8k levels boundaries
2380 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2381 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2382 bigPadding = bigKernel
2383
2384 pad_shape = [0] * (len(k_shape) * 2)
2385 stride_shape = [1] * len(k_shape)
2386 # The point at which an input dimension combined with the stride will
2387 # create large output sizes!
2388 LARGE_SIZE = 2
2389 for idx in range(len(k_shape)):
2390 pad_offset = idx * 2
2391 if k_shape[idx] == bigKernel:
2392 # Set large stride
2393 stride_shape[idx] = bigKernel
2394 # Use negative output padding to reduce shape size
2395 pad_shape[pad_offset] = -(bigPadding - 1)
2396 if ifm_shape[idx + 1] > LARGE_SIZE:
2397 pad_shape[pad_offset + 1] = -(bigPadding - 1)
2398 else:
2399 # The other dimension should be the bigKernel
2400 alt_idx = 1 - idx
2401 if (
2402 k_shape[alt_idx] == bigKernel
2403 and ifm_shape[alt_idx + 1] < LARGE_SIZE
2404 ):
2405 # As the input is small, the large stride won't
2406 # affect the output so we can add some padding
2407 pad_shape[pad_offset + 1] = bigPadding
2408
2409 strides = {tuple(stride_shape)}
2410 paddings = {tuple(pad_shape)}
2411
2412 # Currently allow all combinations that are reasonable size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002413 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002414
2415 n = 0
Tai Lyf36f2562024-03-14 16:21:29 +00002416 for a in accum_dtypes:
2417 for s in sorted(list(strides)):
2418 for p in sorted(list(paddings)):
2419 if n % sparsity == 0:
2420 # Determine the output shape
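# transpose_conv2d output size per spatial dimension:
#   O = (I - 1) * stride + pad_before + pad_after + K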
2421 oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
2422 ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
2423 os = [ifm_shape[0], oh, ow, filter_shape[0]]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002424
Tai Lyf36f2562024-03-14 16:21:29 +00002425 # N*OH*OW*OC
2426 dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
2427 args_dict = {
2428 "acc_type": a,
2429 "stride": s,
2430 "pad": p,
2431 "kernel": k_shape,
2432 "ks": k_size,
2433 "dot_products": dots,
2434 "shape": ifm_shape,
2435 "out_shape": os,
2436 }
Jeremy Johnson95a67102024-01-10 14:16:39 +00002437
Tai Lyf36f2562024-03-14 16:21:29 +00002438 # Supporting values larger than 9 needs a different delimiter
2439 delim = "" if max(s + p) <= 9 else "x"
2440 arg_list.append(
2441 (
2442 "acc{}_st{}_pad{}_os{}".format(
2443 testGen.typeStr(a),
2444 delim.join([str(x) for x in s]),
2445 delim.join([str(x) for x in p]),
2446 "x".join([str(x) for x in os]),
2447 ),
2448 args_dict,
2449 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002450 )
Tai Lyf36f2562024-03-14 16:21:29 +00002451 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002452
Jeremy Johnson95a67102024-01-10 14:16:39 +00002453 arg_list = TosaArgGen._add_data_generators(
2454 testGen,
2455 opName,
evacha019c96eef2024-02-07 11:21:55 +00002456 shapeList,
Jeremy Johnson95a67102024-01-10 14:16:39 +00002457 dtypes[0],
2458 arg_list,
2459 error_name,
2460 )
2461 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002462 return arg_list
2463
2464 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002465 def agPad(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002466 rank = len(shapeList[0])
2467
2468 # Exhaustively test combinations of padding on each side of each dimension
2469 # - the range of padding values is defined by pad_min and pad_max
2470 # - for padding >9, the name format needs to be more distinctive
2471 pad_min, pad_max = 0, 1
2472 pad_values = [x for x in range(pad_min, pad_max + 1)]
2473 if error_name == ErrorIf.PadSmallerZero:
2474 pad_values = [x for x in range(-2, 0)]
2475 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
2476 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
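# e.g. for rank 2 with pad_values [0, 1], axis_pad_values has 4 pairs and
# shape_pad_values yields 4**2 = 16 padding combinations to consider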
2477
2478 if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002479 pad_const_int = rng.randNumberDType(dtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002480 pad_const_fp = 0
Tai Ly60dc48c2024-03-08 22:19:41 +00002481 elif gtu.dtypeIsFloat(dtype):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002482 pad_const_int = 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002483 pad_const_fp = rng.randNumberDType(dtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002484 else:
2485 return []
2486
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002487 list_shape_pad_values = list(shape_pad_values)
2488 # If we are producing tests for rank 6 or greater use sparsity
2489 if len(list_shape_pad_values) > 1024:
2490 sparsity_factor = 2 if error_name else 120
2491 sparsity = TosaArgGen._calculate_sparsity(
2492 len(list_shape_pad_values), sparsity_factor
2493 )
2494 else:
2495 sparsity = 1
2496
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002497 # Build arg list
2498 arg_list = []
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002499 for n, paddings in enumerate(list_shape_pad_values):
James Ward8b390432022-08-12 20:48:56 +01002500 paddings = list(paddings)
2501 args_valid = True
2502
2503 if error_name == ErrorIf.PadSmallerZero:
2504 # Prevent negative output shapes while still ensuring negative padding is tested
2505 for i in range(rank):
2506 dim_after_padding = (
2507 paddings[i][0] + paddings[i][1] + shapeList[0][i]
2508 )
2509 if dim_after_padding < 1:
2510 paddings[i] = (0, 0)
2511 if all([p > -1 for p in paddings[i]]):
2512 args_valid = False
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002513 if args_valid and n % sparsity == 0:
James Ward8b390432022-08-12 20:48:56 +01002514 name = "pad"
2515 for r in range(rank):
2516 before, after = paddings[r]
2517 name = f"{name}{before}{after}"
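# e.g. paddings [(0, 1), (1, 0)] give the test name "pad0110"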
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002518 args_dict = {
2519 "pad": np.array(paddings),
2520 "pad_const_int": pad_const_int,
2521 "pad_const_fp": pad_const_fp,
2522 }
2523 arg_list.append((name, args_dict))
James Ward8b390432022-08-12 20:48:56 +01002524
2525 if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
Jeremy Johnsonaf090182024-02-13 18:25:39 +00002526 logger.info(f"No ErrorIf test created for input shape: {shapeList[0]}")
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002527
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002528 arg_list = TosaArgGen._add_data_generators(
2529 testGen,
2530 opName,
evacha019c96eef2024-02-07 11:21:55 +00002531 shapeList,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002532 dtype,
2533 arg_list,
2534 error_name,
2535 )
2536
2537 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002538 return arg_list
2539
2540 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002541 def agPooling(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002542 arg_list = []
2543
2544 shape = shapeList[0]
2545 if error_name != ErrorIf.WrongRank:
2546 assert len(shape) == 4
2547
Jeremy Johnson0c716862023-04-13 17:18:19 +01002548 test_level8k = testGen.args.level8k and error_name is None
2549
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002550 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002551 startKernel = 2
2552 startPad = 0
2553 if not test_level8k:
2554 # Generate comprehensive argument lists
2555 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2556 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2557 # Stride must be greater than 1 to force non-integer error
2558 s_vals = [
2559 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2560 ]
2561 strides = {x for x in itertools.product(*([s_vals] * 2))}
2562 k_vals = [
2563 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2564 ]
2565 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2566 max_dim_size = None
2567 else:
2568 # Only test 8k levels
2569 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2570 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2571 strides = {(1, bigStride), (bigStride, 4)}
2572 kernels = {(1, bigKernel), (bigKernel, 3)}
2573 paddings = set()
2574 for s in sorted(list(strides)):
2575 for k in sorted(list(kernels)):
2576 padding = []
2577 for idx in range(len(k)):
2578 total_padding = s[idx] - shape[idx + 1] + k[idx]
2579 while total_padding < 0:
2580 # Must meet: shape + padding > kernel
2581 total_padding += s[idx]
2582 if total_padding < k[idx]:
2583 padding.extend([0, total_padding])
2584 else:
2585 # Note this may produce padding >= k[idx] which is not
2586 # allowed - but will be ignored in the creation loop below
2587 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2588 paddings.add(tuple(padding))
2589 # Create a limit for the output dimensions size
2590 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002591
James Ward8b390432022-08-12 20:48:56 +01002592 if opName == "max_pool2d":
2593 accum_dtypes = [None] # max_pool has no accumulate dtype
2594 elif dtype == DType.INT8 or dtype == DType.INT16:
2595 accum_dtypes = [DType.INT32]
2596 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002597 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002598 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002599 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002600 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2601 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002602 elif error_name is None:
2603 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2604 else:
2605 # Set to something for the ErrorIf case which has
2606 # incorrect input data-type
2607 accum_dtypes = [DType.INT32]
2608
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00002609 if error_name == ErrorIf.WrongAccumulatorType:
2610 accum_dtypes = list(gtu.usableDTypes(excludes=accum_dtypes))
2611
Jeremy Johnson0c716862023-04-13 17:18:19 +01002612 if not test_level8k:
2613 if testGen.args.oversize:
2614 # add some oversize argument values
2615 bigStride = 7
2616 bigKernel = 9
2617 strides.update(
2618 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002619 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002620 kernels.update(
2621 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2622 )
2623 if max(shape) < 64:
2624 # padding must be less than the kernel size
2625 bigPadding = bigKernel - 1
2626 paddings.update(
2627 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2628 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002629
Jeremy Johnson0c716862023-04-13 17:18:19 +01002630 # There are too many parameter combinations, so generate them sparsely,
2631 # and even more sparsely for negative tests
2632 sparsity_factor = 2 if error_name else 500
2633 sparsity = (
2634 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2635 )
2636 else:
2637 # We have already limited test output combinations for 8k tests
2638 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002639
James Ward8b390432022-08-12 20:48:56 +01002640 arg_str = (
2641 "acc{}_st{}_kern{}_pad{}"
2642 if accum_dtypes[0] is not None
2643 else "st{}_kern{}_pad{}"
2644 )
2645
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002646 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01002647 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002648 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01002649
2650 # Values larger than 9 need a different delimiter in the test name
2651 delim = "" if max(stride + kern + pad) <= 9 else "x"
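# e.g. stride (2, 2), kernel (3, 3), pad (0, 0, 0, 0) -> "st22_kern33_pad0000";
# a stride of (1, 10) would instead use "st1x10_..." to keep names unambiguous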
James Ward8b390432022-08-12 20:48:56 +01002652 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01002653 delim.join([str(x) for x in stride]),
2654 delim.join([str(x) for x in kern]),
2655 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01002656 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002657 args_dict = {
2658 "stride": stride,
2659 "pad": pad,
2660 "kernel": kern,
2661 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002662 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002663 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
2664 }
James Ward8b390432022-08-12 20:48:56 +01002665
2666 if accum is not None:
2667 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002668 args_dict["acc_type"] = accum
2669 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002670
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002671 n = 0
James Ward8b390432022-08-12 20:48:56 +01002672 for a in accum_dtypes:
2673 for s in sorted(list(strides)):
2674 for p in sorted(list(paddings)):
2675 for k in sorted(list(kernels)):
2676 if error_name in [
2677 ErrorIf.StrideSmallerOne,
2678 ErrorIf.KernelSmallerOne,
2679 ErrorIf.PadSmallerZero,
2680 ErrorIf.PadLargerEqualKernel,
2681 ]:
2682 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002683 rng, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002684 )
James Ward8b390432022-08-12 20:48:56 +01002685 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002686 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002687 get_arg_list_element(a, sNew, pNew, kNew, shape=shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002688 )
James Ward8b390432022-08-12 20:48:56 +01002689 elif (
2690 n % sparsity == 0
2691 # padding must not exceed the kernel size
2692 and p[0] < k[0]
2693 and p[1] < k[0]
2694 and p[2] < k[1]
2695 and p[3] < k[1]
2696 # the padded shape must exceed the kernel size
2697 and (shape[1] + p[0] + p[1]) > k[0]
2698 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002699 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002700 partial_h = shape[1] + p[0] + p[1] - k[0]
2701 partial_w = shape[2] + p[2] + p[3] - k[1]
2702 remainder_h = partial_h % s[0]
2703 remainder_w = partial_w % s[1]
2704 output_h = partial_h // s[0] + 1
2705 output_w = partial_w // s[1] + 1
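# e.g. shape[1] = 9, p = (0, 0), k[0] = 3, s[0] = 2:
# partial_h = 9 + 0 + 0 - 3 = 6, remainder_h = 0, output_h = 6 // 2 + 1 = 4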
Jeremy Johnsonaf090182024-02-13 18:25:39 +00002706 logger.debug(
2707 f"agPooling: {shape} remainder=({remainder_h}, {remainder_w}) output=({output_h}, {output_w})"
2708 )
James Ward8b390432022-08-12 20:48:56 +01002709 if (
2710 # the parameters must produce integer exact output
2711 error_name != ErrorIf.PoolingOutputShapeNonInteger
2712 and remainder_h == 0
2713 and remainder_w == 0
2714 ) or (
2715 error_name == ErrorIf.PoolingOutputShapeNonInteger
2716 and (remainder_h != 0 or remainder_w != 0)
2717 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002718 if (
2719 max_dim_size is not None
2720 and max(output_h, output_w) > max_dim_size
2721 ):
2722 # Test will consume too much memory - skip it
2723 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002724 # Dot products = N*OH*OW*C
2725 dp = gtu.product(
2726 (shape[0], output_h, output_w, shape[3])
2727 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002728 arg_list.append(
2729 get_arg_list_element(a, s, p, k, dp, shape)
2730 )
James Ward8b390432022-08-12 20:48:56 +01002731 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002732
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002733 # Now add data generator types
2734 arg_list = TosaArgGen._add_data_generators(
2735 testGen,
2736 opName,
evacha019c96eef2024-02-07 11:21:55 +00002737 shapeList,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002738 dtype,
2739 arg_list,
2740 error_name,
2741 )
2742
2743 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002744 return arg_list
2745
2746 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002747 def agCast(testGen, rng, opName, shapeList, inDtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002748 arg_list = []
2749
2750 # Enumerate the output types here
2751 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002752 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(inDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002753 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00002754 dtypeList = [
2755 DType.BOOL,
2756 DType.INT16,
2757 DType.INT32,
2758 DType.FP16,
2759 DType.BF16,
2760 DType.FP32,
2761 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002762 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00002763 dtypeList = [
2764 DType.BOOL,
2765 DType.INT8,
2766 DType.INT32,
2767 DType.FP16,
2768 DType.BF16,
2769 DType.FP32,
2770 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002771 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00002772 dtypeList = [
2773 DType.BOOL,
2774 DType.INT8,
2775 DType.INT16,
2776 DType.FP16,
2777 DType.BF16,
2778 DType.FP32,
2779 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002780 elif inDtype == DType.BOOL:
2781 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01002782 elif inDtype == DType.FP16:
Won Jeon2c34b462024-02-06 18:37:00 +00002783 dtypeList = [
2784 DType.INT8,
2785 DType.INT16,
2786 DType.INT32,
2787 DType.FP32,
2788 DType.FP8E4M3,
2789 DType.FP8E5M2,
2790 ]
James Ward24dbc422022-10-19 12:20:31 +01002791 elif inDtype == DType.BF16:
Won Jeon2c34b462024-02-06 18:37:00 +00002792 dtypeList = [
2793 DType.INT8,
2794 DType.INT16,
2795 DType.INT32,
2796 DType.FP32,
2797 DType.FP8E4M3,
2798 DType.FP8E5M2,
2799 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002800 elif inDtype == DType.FP32:
Won Jeon2c34b462024-02-06 18:37:00 +00002801 dtypeList = [
2802 DType.INT8,
2803 DType.INT16,
2804 DType.INT32,
2805 DType.FP16,
2806 DType.BF16,
2807 DType.FP8E4M3,
2808 DType.FP8E5M2,
2809 ]
2810 elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]:
2811 dtypeList = [DType.FP16, DType.BF16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002812 elif error_name == ErrorIf.WrongInputType:
2813 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002814 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002815 else:
2816 raise Exception("Unexpected input dtype: {}".format(inDtype))
2817
2818 for dtype in dtypeList:
Jeremy Johnson708da822023-11-15 16:25:45 +00002819 arg_list.append(
2820 ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
2821 )
2822
2823 # Now add data generator types
2824 arg_list = TosaArgGen._add_data_generators(
2825 testGen,
2826 opName,
evacha019c96eef2024-02-07 11:21:55 +00002827 shapeList,
Jeremy Johnson708da822023-11-15 16:25:45 +00002828 dtype,
2829 arg_list,
2830 error_name,
2831 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002832
2833 return arg_list
2834
2835 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002836 def agRescale(testGen, rng, opName, shapeList, inDtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002837 arg_list = []
2838
2839 # Enumerate the output types here
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002840 for outDtype in [
2841 DType.UINT8,
2842 DType.INT8,
2843 DType.INT16,
2844 DType.INT32,
2845 DType.UINT16,
2846 ]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002847 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002848 outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002849 and error_name == ErrorIf.OutputZeroPointNotZero
2850 ):
2851 continue
2852 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002853 outDtype != DType.UINT16
2854 and error_name == ErrorIf.U16OutputZeroPointNotValid
2855 ) or (
2856 inDtype != DType.UINT16
2857 and error_name == ErrorIf.U16InputZeroPointNotValid
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002858 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002859 # ErrorIfs only valid with UINT16
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002860 continue
2861 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002862 inDtype == DType.UINT8
2863 and outDtype not in [DType.INT8, DType.INT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002864 and error_name != ErrorIf.WrongOutputType
2865 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002866 # The only output dtypes for UINT8 are INT8/INT16, skip all others
2867 continue
2868 if (
2869 inDtype not in [DType.INT8, DType.INT16]
2870 and outDtype == DType.UINT8
2871 and error_name != ErrorIf.WrongOutputType
2872 ):
2873 # The only input dtypes for UINT8 are INT8/INT16, skip all others
2874 continue
2875 if (
2876 inDtype == DType.UINT16
2877 and outDtype != DType.INT16
2878 and error_name != ErrorIf.WrongOutputType
2879 ):
2880 # The only output dtype for UINT16 is INT16, skip all others
2881 continue
2882 if (
2883 inDtype != DType.INT16
2884 and outDtype == DType.UINT16
2885 and error_name != ErrorIf.WrongOutputType
2886 ):
2887 # The only input dtype for UINT16 is INT16, skip all others
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002888 continue
2889 if (
2890 error_name == ErrorIf.WrongOutputType
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002891 and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002892 ):
2893 continue
2894
2895 for scale32 in [False, True]:
2896 if error_name == ErrorIf.ScaleTrue and not scale32:
2897 continue
2898 elif error_name == ErrorIf.ScaleNotTrue and scale32:
2899 continue
2900 for double_round in [False, True]:
2901 if error_name == ErrorIf.ScaleNotTrue and not double_round:
2902 continue
2903 for per_channel in [False, True]:
2904
2905 if (
2906 inDtype == DType.INT48
2907 and scale32
2908 and error_name != ErrorIf.ScaleTrue
2909 ):
2910 # Illegal condition. Must be scale32=False
2911 continue
2912 if (
2913 double_round
2914 and not scale32
2915 and error_name != ErrorIf.ScaleNotTrue
2916 ):
2917 # Illegal condition. ERROR_IF(!scale32 && double_round)
2918 continue
2919
Tai Ly6e1e2bc2024-03-01 20:59:32 +00002920 if per_channel:
2921 nc = shapeList[0][-1]
2922 else:
2923 nc = 1
2924
2925 in_type_width = gtu.dtypeWidth(inDtype)
2926 out_type_width = gtu.dtypeWidth(outDtype)
2927
2928 # Calculate scale based on:
2929 # scale = a * (2^output_width) / (2^input_width)
2930
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002931 a = np.float32(rng.random(size=[nc]))
Tai Ly6e1e2bc2024-03-01 20:59:32 +00002932 scale_arr = a * np.float32(
2933 (1 << out_type_width) / (1 << in_type_width)
2934 )
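# e.g. for INT8 -> INT16 (type widths 8 and 16) the random values in a
# are scaled by (1 << 16) / (1 << 8) = 256 to roughly match the widening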
2935
2936 if scale32:
2937 # Cap the scaling at 2^31 - 1 for scale32
2938 scale_arr = np.clip(
2939 scale_arr, 1.0 / (1 << 31), (1 << 31) - 1
2940 )
2941 else:
2942 # Cap the scaling at 2^15 - 1 for scale16
2943 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
2944
Jeremy Johnsonaf090182024-02-13 18:25:39 +00002945 logger.debug(
2946 f"agRescale: {out_type_width} {in_type_width} -> {scale_arr}"
2947 )
Tai Ly6e1e2bc2024-03-01 20:59:32 +00002948
2949 multiplier_arr = np.int32(np.zeros(shape=[nc]))
2950 shift_arr = np.int32(np.zeros(shape=[nc]))
2951 for i in range(nc):
2952 (
2953 multiplier_arr[i],
2954 shift_arr[i],
2955 ) = TosaQuantGen.computeMultiplierAndShift(
2956 scale_arr[i], scale32
2957 )
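# computeMultiplierAndShift is assumed to return a fixed-point pair where
# scale_arr[i] ~= multiplier_arr[i] * 2**(-shift_arr[i]), using a 32-bit
# multiplier when scale32 is True and a 16-bit multiplier otherwise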
2958
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002959 arg_list.append(
2960 (
2961 "out{}_sc{}_dr{}_pc{}".format(
Jeremy Johnson3b0544c2022-10-18 16:32:19 +01002962 testGen.typeStr(outDtype),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002963 int(scale32),
2964 int(double_round),
2965 int(per_channel),
2966 ),
Jeremy Johnson587cc842024-02-08 11:45:44 +00002967 {
2968 "output_dtype": outDtype,
2969 "scale": scale32,
2970 "double_round": double_round,
2971 "per_channel": per_channel,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00002972 "multiplier": multiplier_arr,
2973 "shift": shift_arr,
Jeremy Johnson587cc842024-02-08 11:45:44 +00002974 },
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002975 )
2976 )
2977
Jeremy Johnson587cc842024-02-08 11:45:44 +00002978 arg_list = TosaArgGen._add_data_generators(
2979 testGen,
2980 opName,
evacha019c96eef2024-02-07 11:21:55 +00002981 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00002982 inDtype,
2983 arg_list,
2984 error_name,
2985 )
2986 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002987 return arg_list
2988
2989 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002990 def agMul(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002991 arg_list = []
2992
2993 if dtype is DType.INT32:
2994 for p in range(testGen.args.num_rand_permutations):
2995
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002996 shift = rng.randInt(0, 32)
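# The shift attribute (INT32 MUL only) is understood to right-shift the
# product with rounding before narrowing back to INT32; all other dtypes
# get a fixed shift of 0 in the else branch below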
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002997 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002998 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002999 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003000
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003001 arg_list = TosaArgGen._add_data_generators(
3002 testGen,
3003 opName,
evacha019c96eef2024-02-07 11:21:55 +00003004 shapeList,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003005 dtype,
3006 arg_list,
3007 error_name,
3008 )
3009 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003010 return arg_list
3011
3012 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003013 def agArithmeticRightShift(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003014 arg_list = []
3015
Jeremy Johnson587cc842024-02-08 11:45:44 +00003016 for round in (True, False):
3017 args_dict = {
3018 "round": round,
3019 }
3020 arg_list.append((f"round{round}", args_dict))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003021
Jeremy Johnson587cc842024-02-08 11:45:44 +00003022 arg_list = TosaArgGen._add_data_generators(
3023 testGen,
3024 opName,
evacha019c96eef2024-02-07 11:21:55 +00003025 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003026 dtype,
3027 arg_list,
3028 error_name,
3029 )
3030 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003031 return arg_list
3032
Luke Hutton57287132023-02-06 14:54:18 +00003033 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003034 def agFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Luke Hutton57287132023-02-06 14:54:18 +00003035 arg_list = []
3036
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003037 shape = shapeList[0]
3038 dot_products = gtu.product(shape)
3039 ks = 2 * shape[1] * shape[2] # 2*H*W
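# Each FFT2d output element is treated as a dot product over both the
# real and imaginary input tensors, hence KS = 2*H*W for the
# floating-point compliance bounds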
3040 for inverse in (True, False):
3041 args_dict = {
3042 "dot_products": dot_products,
3043 "shape": shape,
3044 "ks": ks,
3045 "acc_type": dtype,
3046 "inverse": inverse,
3047 }
3048 arg_list.append((f"inverse{inverse}", args_dict))
Luke Hutton57287132023-02-06 14:54:18 +00003049
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003050 arg_list = TosaArgGen._add_data_generators(
3051 testGen,
3052 opName,
evacha019c96eef2024-02-07 11:21:55 +00003053 shapeList,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003054 dtype,
3055 arg_list,
3056 error_name,
3057 )
3058 # Return list of tuples: (arg_str, args_dict)
Luke Hutton57287132023-02-06 14:54:18 +00003059 return arg_list
3060
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003061 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003062 def agRFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003063 arg_list = []
3064
3065 shape = shapeList[0]
3066 dot_products = gtu.product(shape)
3067 ks = shape[1] * shape[2] # H*W
3068 args_dict = {
3069 "dot_products": dot_products,
3070 "shape": shape,
3071 "ks": ks,
3072 "acc_type": dtype,
3073 }
3074 arg_list.append(("", args_dict))
3075
3076 arg_list = TosaArgGen._add_data_generators(
3077 testGen,
3078 opName,
evacha019c96eef2024-02-07 11:21:55 +00003079 shapeList,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003080 dtype,
3081 arg_list,
3082 error_name,
3083 )
3084 # Return list of tuples: (arg_str, args_dict)
3085 return arg_list
3086
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003087 # Helper function for reshape. Gets some factors of a larger number.
3088 @staticmethod
3089 def getFactors(val, start=1):
3090 factors = []
3091
3092 for i in range(start, int(np.sqrt(val)) + 1):
3093 if (val % i) == 0:
3094 factors.append(i)
3095
3096 return factors
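# e.g. getFactors(12) returns [1, 2, 3] - only factors up to sqrt(val) are
# collected; callers re-factorise the remaining element count as needed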
3097
3098 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003099 def agReshape(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003100 arg_list = []
3101
3102 origShape = shapeList[0]
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003103 totalElements = gtu.product(origShape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003104 factors = TosaArgGen.getFactors(totalElements)
3105
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003106 # Find new shapes up to the number of permutations asked for
3107 # This code is NOT fast. Fortunately, the numbers are fairly small.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003108 for p in range(testGen.args.num_rand_permutations):
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00003109 # Rank from 1 to TOSA_TENSOR_MAX_RANK
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003110 newRank = rng.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003111 if len(factors) < newRank:
3112 continue
3113
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003114 # escape_counter limits the generation of new shapes to a reasonable time
3115 for escape_counter in range(100):
3116
3117 # Generate the new shape of the chosen new rank
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003118 newShape = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003119 remainingElements = totalElements
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003120 shuffledFactors = rng.permutation(factors)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003121 for i in range(1, newRank):
3122 # pick rank-1 factors
3123 newShape.append(shuffledFactors[0])
3124 remainingElements = remainingElements // shuffledFactors[0]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003125 shuffledFactors = rng.permutation(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003126 TosaArgGen.getFactors(remainingElements)
3127 )
3128 newShape.append(remainingElements)
3129
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003130 # Check for duplicates
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003131 duplicate = False
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003132 for name, args_dict in arg_list:
3133 if args_dict["new_shape"] == newShape:
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003134 duplicate = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003135 break
3136
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003137 if not duplicate:
3138 outShape = "x".join([str(x) for x in newShape])
3139 arg_list.append(
3140 (
3141 "perm{}_rank{}_out{}".format(p, newRank, outShape),
3142 {"new_shape": newShape},
3143 )
3144 )
3145 # Found an output shape for this permutation
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003146 break
3147
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003148 # Now add data generator types
3149 arg_list = TosaArgGen._add_data_generators(
3150 testGen,
3151 opName,
evacha019c96eef2024-02-07 11:21:55 +00003152 shapeList,
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003153 dtype,
3154 arg_list,
3155 error_name,
3156 )
3157
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003158 return arg_list
3159
3160 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003161 def agTranspose(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003162 arg_list = []
3163
3164 ifm_shape = shapeList[0]
3165
3166 if error_name == ErrorIf.IndexOutsideBounds:
3167 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
3168 incorrect_small_index = range(-len(ifm_shape), 0)
3169 permutations = [p for p in itertools.permutations(incorrect_large_index)]
3170 permutations.extend(
3171 [p for p in itertools.permutations(incorrect_small_index)]
3172 )
3173 elif error_name == ErrorIf.IndexUsedTwice:
3174 # Create list with a duplicated index
3175 perm_range = list(range(len(ifm_shape)))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003176 index_choice = rng.choice(range(len(perm_range)))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003177 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
3178 permutations = [p for p in itertools.permutations(perm_range)]
3179
3180 else:
3181 # Get all permutations
3182 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
3183
3184 # Limit to possible permutations from shape dimension or argument setting
3185 limit = min(len(permutations), testGen.args.num_rand_permutations)
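# e.g. for a rank 3 input there are 3! = 6 valid permutations, so at most
# 6 tests are produced even if num_rand_permutations is larger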
3186
3187 # Get random permutation generator that uses all permutations
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003188 random_permutations = rng.permutation(permutations)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003189
3190 # Create list of required amount of permutations
3191 arg_list = [
evacha0198477222024-01-26 12:25:32 +00003192 ("perm{}".format(p), {"perms": random_permutations[p].tolist()})
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003193 for p in range(limit)
3194 ]
evacha0198477222024-01-26 12:25:32 +00003195 # Now add data generator types
3196 arg_list = TosaArgGen._add_data_generators(
3197 testGen,
3198 opName,
evacha019c96eef2024-02-07 11:21:55 +00003199 shapeList,
evacha0198477222024-01-26 12:25:32 +00003200 dtype,
3201 arg_list,
3202 error_name,
3203 )
3204 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003205 return arg_list
3206
3207 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003208 def agSlice(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003209 arg_list = []
3210
3211 ifm_shape = shapeList[0]
3212 rank = len(ifm_shape)
3213
3214 for p in range(testGen.args.num_rand_permutations):
3215 start = []
3216 size = []
3217
3218 valid = True
3219
3220 for i in range(rank):
3221 if ifm_shape[i] > 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003222 start.append(rng.randInt(0, ifm_shape[i]))
3223 size.append(rng.randInt(0, ifm_shape[i] - start[i]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003224
3225 # Invalid slice size?
3226 if size[i] == 0:
3227 valid = False
3228 else:
3229 start.append(0)
3230 size.append(1)
3231
3232 if valid:
3233 # If ERROR_IF test required then incorrect start, size will be returned
3234 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003235 rng, error_name, ifm_shape, start, size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003236 )
evacha017f7d4252024-01-24 12:08:09 +00003237 arg_list.append(("perm{}".format(p), {"start": start, "size": size}))
3238 # Now add data generator types
3239 arg_list = TosaArgGen._add_data_generators(
3240 testGen,
3241 opName,
evacha019c96eef2024-02-07 11:21:55 +00003242 shapeList,
evacha017f7d4252024-01-24 12:08:09 +00003243 dtype,
3244 arg_list,
3245 error_name,
3246 )
3247 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003248 return arg_list
3249
3250 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003251 def agTile(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003252 arg_list = []
3253
3254 ifm_shape = shapeList[0]
3255 rank = len(ifm_shape)
3256
3257 for p in range(testGen.args.num_rand_permutations):
3258
3259 # Pick a few random but small multiple values,
3260 # because otherwise this has a tendency to generate
3261 # enormous tensors
3262 multiples = []
3263 for i in range(rank):
3264 if ifm_shape[i] > 1000:
3265 # Use a multiple of 1 if the ifm_shape dimension is large, to
3266 # reduce the tensor size
3267 multiples.append(1)
3268 elif max(ifm_shape) > 1000:
3269 multiples.append(2)
3270 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003271 multiples.append(rng.randInt(1, 4))
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003272 arg_list.append(("perm{}".format(p), {"multiples": multiples}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003273
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003274 # Now add data generator types
3275 arg_list = TosaArgGen._add_data_generators(
3276 testGen,
3277 opName,
evacha019c96eef2024-02-07 11:21:55 +00003278 shapeList,
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003279 dtype,
3280 arg_list,
3281 error_name,
3282 )
3283 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003284 return arg_list
3285
3286 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003287 def agResize(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003288 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003289 ifm_shape = shapeList[0]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003290
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003291 def get_aspect_ratio_resize_params():
3292 common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003293 aspect_ratio = rng.choice(common_aspect_ratios)
3294 invert = rng.choice((False, True))
3295 letterbox = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003296
3297 scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
3298 scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
3299 scale_y_d = scale_x_d = 1
3300 offset_x = offset_y = 0
3301
3302 if letterbox:
3303 max_border = scale_y_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003304 border_y = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003305 border_x = 0
3306 else:
3307 # Pillarboxing
3308 border_y = 0
3309 max_border = scale_x_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003310 border_x = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003311
3312 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3313 offset = (offset_y, offset_x)
3314 border = (border_y, border_x)
3315
3316 return scale, offset, border
3317
3318 def get_upscale_downscale_params():
3319 valid_params = False
3320 while not valid_params:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003321 upscale = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003322
3323 # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003324 origin_sampling = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003325
3326 if upscale:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003327 shift = rng.randInt(low=1, high=4)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003328 scale_x_d = scale_y_d = 1
3329 scale_x_n = scale_y_n = (
3330 1 << shift if origin_sampling else 2 << shift
3331 )
3332 border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
3333 offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
3334 else:
3335 scale_x_n = 1
3336 scale_y_n = 1
3337
3338 # Return list of valid scale_*_d values (max value 4) given input dim shape
3339 def get_valid_denom(ifm_dim):
3340 return [x for x in range(1, 5) if ifm_dim % x == 1]
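# e.g. ifm_dim = 9 gives [2, 4]; note x = 1 can never satisfy
# ifm_dim % x == 1, so a denominator of 1 is handled by the fallback below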
3341
3342 # Generate list of valid downscale values and choose one randomly
3343 valid_scale_y_ds = get_valid_denom(ifm_shape[1])
3344 valid_scale_x_ds = get_valid_denom(ifm_shape[2])
3345
3346 if not valid_scale_y_ds and not valid_scale_x_ds:
3347 # Bad parameters, skip
3348 continue
3349
3350 if not valid_scale_y_ds:
3351 scale_y_d = 1
3352 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003353 scale_y_d = rng.choice(valid_scale_y_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003354
3355 if not valid_scale_x_ds:
3356 scale_x_d = 1
3357 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003358 scale_x_d = rng.choice(valid_scale_x_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003359
3360 border_x = border_y = 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003361 offset_y = rng.randInt(0, 16 * scale_y_n)
3362 offset_x = rng.randInt(0, 16 * scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003363 valid_params = True
3364
3365 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3366 offset = (offset_y, offset_x)
3367 border = (border_y, border_x)
3368 return scale, offset, border
3369
3370 def get_rand_params():
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003371 def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
3372 scale = scale_n / scale_d
3373 if scale > max_scale:
3374 factor = scale / max_scale
3375 new_scale_d = math.ceil(scale_d * factor)
3376 assert scale_n / new_scale_d <= max_scale
3377 scale_d = new_scale_d
3378 return scale_d
3379
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003380 # Scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003381 scale_y_n = rng.randInt(low=1, high=(1 << 11))
3382 scale_x_n = rng.randInt(low=1, high=(1 << 11))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003383
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003384 scale_y_d = rng.randInt(low=1, high=(16 * scale_y_n))
3385 scale_x_d = rng.randInt(low=1, high=(16 * scale_x_n))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003386
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003387 scale_y_d = fix_scale_to_max_scale(
3388 scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3389 )
3390 scale_x_d = fix_scale_to_max_scale(
3391 scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3392 )
3393
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003394 # Offsets and border within the scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003395 offset_y = rng.randInt(low=-scale_y_n, high=(16 * scale_y_n))
3396 offset_x = rng.randInt(low=-scale_x_n, high=(16 * scale_x_n))
3397 border_y = rng.randInt(low=(-16 * scale_y_n), high=scale_y_n)
3398 border_x = rng.randInt(low=(-16 * scale_x_n), high=scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003399
3400 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3401 offset = (offset_y, offset_x)
3402 border = (border_y, border_x)
3403 return scale, offset, border
3404
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003405 def get_level_8k_params():
3406 # Create 64x scale - 64/1 to 2048/32
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003407 scale_d = rng.randInt(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003408 low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
3409 )
3410 scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
3411 # Create half to fifth scaling
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003412 scale_d_alt = rng.randInt(low=2, high=6)
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003413 scale_n_alt = 1
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003414 switch = rng.choice((False, True))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003415 if switch:
3416 scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
3417 else:
3418 scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)
3419
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003420 offset_y = rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
3421 offset_x = rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003422 offset = (offset_y, offset_x)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003423 border_y = rng.choice((-16 * scale[0], 0, scale[0] - 1))
3424 border_x = rng.choice((-16 * scale[2], 0, scale[2] - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003425 border = (border_y, border_x)
3426 return scale, offset, border
3427
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003428 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003429 # Exclude illegal {mode, type} configurations. Pick legal output types
3430 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
3431 outputDTypeList = [DType.INT8]
3432 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
3433 outputDTypeList = [DType.INT16]
3434 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
3435 outputDTypeList = [DType.INT32]
3436 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
3437 outputDTypeList = [DType.INT48]
James Ward8b390432022-08-12 20:48:56 +01003438 elif dtype == DType.FP16:
3439 outputDTypeList = [DType.FP16]
James Ward24dbc422022-10-19 12:20:31 +01003440 elif dtype == DType.BF16:
3441 outputDTypeList = [DType.BF16]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003442 elif dtype == DType.FP32:
3443 outputDTypeList = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00003444 elif dtype == DType.FP8E4M3:
3445 outputDTypeList = [DType.FP8E4M3]
3446 elif dtype == DType.FP8E5M2:
3447 outputDTypeList = [DType.FP8E5M2]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003448 elif error_name == ErrorIf.WrongInputType:
3449 # If an incorrect input type is used then we set a 'correct'
3450 # output type to avoid other errors
3451 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
3452 else:
3453 continue
3454
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003455 arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"
3456
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003457 for outputDType in outputDTypeList:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003458 perm = 0
3459 while perm < testGen.args.num_rand_permutations:
3460 # Random choice of type of params we are testing
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003461 if not testGen.args.level8k:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003462 _rnd_param_fn = rng.choice(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003463 (
3464 get_rand_params,
3465 get_upscale_downscale_params,
3466 get_aspect_ratio_resize_params,
3467 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003468 )
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003469 scale, offset, border = _rnd_param_fn()
3470 else:
3471 scale, offset, border = get_level_8k_params()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003472
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003473 # Expand params for bounds-checking
3474 (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
3475 (offset_y, offset_x) = offset
3476 (border_y, border_x) = border
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003477
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003478 # Make sure output dimensions OH and OW are integers
3479 partial_output_y = (
3480 (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
3481 )
3482 partial_output_x = (
3483 (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
3484 )
3485 if error_name == ErrorIf.ResizeOutputShapeNonInteger:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003486 # Look for non-integer test
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003487 if (
3488 partial_output_y % scale_y_d == 0
3489 and partial_output_x % scale_x_d == 0
3490 ):
3491 # Skip this test as it doesn't produce NonInteger output
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003492 if perm > 0:
3493 perm += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003494 continue
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003495 else:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003496 # Alter the scaling factors to make the output integer
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003497 while partial_output_y % scale_y_d != 0:
3498 scale_y_d -= 1
3499 while partial_output_x % scale_x_d != 0:
3500 scale_x_d -= 1
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003501 # Make sure we are still within max scaling
3502 if (
3503 scale_y_n / scale_y_d
3504 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
3505 scale_x_n / scale_x_d
3506 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
3507 # Skip the test as it is using too large a scaling factor
3508 if perm > 0:
3509 perm += 1
3510 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003511
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003512 output_y = partial_output_y // scale_y_d + 1
3513 output_x = partial_output_x // scale_x_d + 1
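# e.g. ifm_shape[1] = 16, scale 2/1, offset 0, border 0:
# partial_output_y = (16 - 1) * 2 - 0 + 0 = 30, output_y = 30 // 1 + 1 = 31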
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003514
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003515 if (
3516 output_y >= testGen.args.max_resize_output_dim
3517 or output_x >= testGen.args.max_resize_output_dim
3518 ) and error_name is None:
3519 # Skip positive test if output dim will be too high
3520 # Avoid high test latency and OOM issues
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003521 if not testGen.args.level8k or perm > 0:
3522 perm += 1
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003523 continue
3524
3525 if (
3526 output_y <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003527 or output_y >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003528 or output_x <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003529 or output_x >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003530 ):
3531 # Output dimensions out of scope
3532 if error_name is not None and perm > 0:
3533 # As long as we have one ERROR_IF test, don't worry
3534 # about creating all the other permutations
3535 perm += 1
3536 continue
3537
3538 if error_name == ErrorIf.ResizeOutputShapeMismatch and (
3539 (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003540 output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003541 and output_y - scale_y_d < 1
3542 )
3543 or (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003544 output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003545 and output_x - scale_x_d < 1
3546 )
3547 ):
3548 # Can't create a negative test with these params as it
3549 # will create invalid output size
3550 if perm > 0:
3551 perm += 1
3552 continue
3553
3554 scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
3555 offset = [offset_y, offset_x]
3556 border = [border_y, border_x]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003557
3558 # Common for all data types
3559 if error_name is not None:
3560 (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003561 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003562 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003563 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003564 outputDTypeNew,
3565 ) = TosaErrorIfArgGen.eiResizeErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003566 rng,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003567 error_name,
3568 mode,
3569 dtype,
3570 shapeList,
3571 outputDType,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003572 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003573 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003574 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003575 )
3576 else:
3577 outputDTypeNew = outputDType
3578
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003579 arg_to_append = (
3580 arg_str.format(
3581 "N" if mode == ResizeMode.NEAREST else "B",
3582 testGen.typeStr(outputDTypeNew),
3583 scale[0],
3584 scale[1],
3585 scale[2],
3586 scale[3],
3587 offset[0],
3588 offset[1],
3589 border[0],
3590 border[1],
3591 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003592 {
3593 "mode": mode,
3594 "scale": scale,
3595 "offset": offset,
3596 "border": border,
3597 "output_dtype": outputDTypeNew,
3598 },
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003599 )
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003600 if arg_to_append in arg_list:
3601 # Skip already generated test params
3602 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003603
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003604 # Valid permutation
3605 perm += 1
3606 arg_list.append(arg_to_append)
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003607
3608 # Now add data generator types
3609 arg_list = TosaArgGen._add_data_generators(
3610 testGen,
3611 opName,
evacha019c96eef2024-02-07 11:21:55 +00003612 shapeList,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003613 dtype,
3614 arg_list,
3615 error_name,
3616 )
3617 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003618 return arg_list
3619
3620 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003621 def agTable(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003622 arg_list = []
3623
3624 if dtype == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003625 table = np.int32(rng.integers(low=-128, high=128, size=[256])).tolist()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003626 else: # INT16
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003627 table = np.int32(rng.integers(low=-32768, high=32768, size=[513])).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003628 # Make sure all slopes are within REQUIRE min/max 16-bit int
3629 for idx in range(len(table) - 1):
3630 slope = table[idx + 1] - table[idx]
3631 # Alter the next table entry to force the slope to be ok
3632 if slope > 32767:
3633 table[idx + 1] -= slope - 32767
3634 if slope < -32768:
3635 table[idx + 1] -= slope + 32768
3636 slope = table[idx + 1] - table[idx]
3637 assert slope <= 32767 and slope >= -32768
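# e.g. table[idx] = -30000, table[idx + 1] = 10000: slope = 40000, so
# table[idx + 1] is reduced by 40000 - 32767 = 7233 to 2767, giving the
# maximum allowed slope of 32767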
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003638 arg_list.append(
3639 (
3640 "",
Jeremy Johnson587cc842024-02-08 11:45:44 +00003641 {"table": table},
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003642 )
3643 )
Jeremy Johnson587cc842024-02-08 11:45:44 +00003644 # Now add data generator types
3645 arg_list = TosaArgGen._add_data_generators(
3646 testGen,
3647 opName,
evacha019c96eef2024-02-07 11:21:55 +00003648 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003649 dtype,
3650 arg_list,
3651 error_name,
3652 )
3653 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003654 return arg_list
3655
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003656 def agCondIf(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003657 # CondIf generates the condition values here.
3658 # Convert to tensors in the build function, along with the
3659 # then and else blocks
3660 arg_list = []
3661
3662 for c in [False, True]:
Jeremy Johnson587cc842024-02-08 11:45:44 +00003663 arg_list.append(("cond{}".format(int(c)), {"condition": c}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003664
Jeremy Johnson587cc842024-02-08 11:45:44 +00003665 # Now add data generator types
3666 arg_list = TosaArgGen._add_data_generators(
3667 testGen,
3668 opName,
evacha019c96eef2024-02-07 11:21:55 +00003669 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003670 dtype,
3671 arg_list,
3672 error_name,
3673 )
3674 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003675 return arg_list
3676
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003677 def agWhileLoop(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003678 # While loop: 0 iterations, 1, more than 1
3679 arg_list = []
3680
Jeremy Johnson587cc842024-02-08 11:45:44 +00003681 for iterations in [0, 1, 4]:
3682 arg_list.append(("iter{}".format(iterations), {"iterations": iterations}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003683
Jeremy Johnson587cc842024-02-08 11:45:44 +00003684 # Now add data generator types
3685 arg_list = TosaArgGen._add_data_generators(
3686 testGen,
3687 opName,
evacha019c96eef2024-02-07 11:21:55 +00003688 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003689 dtype,
3690 arg_list,
3691 error_name,
3692 )
3693 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003694 return arg_list