# Copyright (c) 2021-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
import logging
import math

import generator.tosa_utils as gtu
import numpy as np
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode

# DTypeNames, DType, Op and ResizeMode are convenience variables to the
# flatc-generated types that should be enums, but aren't

logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")


class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """
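
    # Illustrative note (assumed usage, not shown in this file): operator
    # entries in TOSA_OP_LIST are expected to select one of the qg* helpers
    # below via their 'qgen' field, e.g. "qgen": TosaQuantGen.qgUnary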

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(rng, zeropoint, dtype, error_name=None):

        if dtype == DType.INT8:
            if zeropoint is not None:
                return min(127, max(-128, zeropoint))
            return rng.randInt(-128, 128)
        elif dtype == DType.UINT8:
            if zeropoint is not None:
                return min(255, max(0, zeropoint))
            return rng.randInt(0, 256)
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            zero_point = rng.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

    @staticmethod
    def qgUnary(rng, zeropoint, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
            ]
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
            ]
        return qinfo

    @staticmethod
    def qgConv(rng, zeropoint, op, dtype_or_dtypeList, error_name=None):
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0], error_name),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1]),
            ]
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0]),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1], error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[0]),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtypeList[1]),
            ]
        return qinfo

    @staticmethod
    def qgMatmul(rng, zeropoint, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
                TosaQuantGen.getZeroPoint(rng, zeropoint, dtype),
            ]
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        logger.debug(
            f"computeMultiplierAndShift: scalefp={scaleFp} scaleBits={scaleBits} m={m} mult={multiplier} shift={shift}"
        )

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift


class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]
        shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rng, rank + rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rng, rank + rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rng, rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgGather(testGen, rng, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_shape = testGen.makeShape(rng, rank)
        values_shape = testGen.constrictBatchSize(values_shape)

        N = values_shape[0]
        W = testGen.makeDimension(rng)
        indices_shape = [N, W]

        shape_list = [values_shape, indices_shape]
        return shape_list

    @staticmethod
    def tgScatter(testGen, rng, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rng, rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = rng.randInt(W_min, W_max) if W_min < W_max else W_min
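        # Illustrative example: with a user tensor shape range of (1, 32) and
        # K == 5, W_min == 1 and W_max == 5, so W can never exceed K and each
        # output index is written at most once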

        input_shape = [N, W, C]

        shape_list = []
        shape_list.append(values_in_shape)
        shape_list.append([N, W])  # indices
        shape_list.append(input_shape)

        return shape_list

    @staticmethod
    def _get_broadcast_shapes(testGen, rng, num_shapes, rank, error_name=None):
        shape = testGen.makeShape(rng, rank)
        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = rng.randInt(0 if error_name is None else 1, num_shapes)
        fuzz_idx = rng.randInt(0, rank)

        for i in range(num_shapes):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100277 # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(rng, extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]
        num_shapes = pl + const
        return TosaTensorGen._get_broadcast_shapes(
            testGen, rng, num_shapes, rank, error_name
        )

    @staticmethod
    def tgMul(testGen, rng, op, rank, error_name=None):
        # Get broadcast shapes for the first 2 inputs as the 3rd is shift
        shape_list = TosaTensorGen._get_broadcast_shapes(
            testGen, rng, 2, rank, error_name
        )
        # Add a single dimension tensor for shift
        shape_list.append([1])
        return shape_list

    @staticmethod
    def tgConv2D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension(rng)

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension(rng)

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension(rng)

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rng, rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension(rng) % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rng, rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))
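        # e.g. an input height of 10 becomes 2 ** int(math.log(10, 2)) == 2 ** 3 == 8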

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rng, rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rng, rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, rng, op, rank, error_name=None):
        pl, const = op["operands"]
        shape = testGen.makeShape(rng, rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = rng.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(rng, shapeList, axis, error_name=None):
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList


class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each tensor in each test."""

    def __init__(self):
        pass

    class TVGInfo:
        """Enhanced tensor values information including data gen dict."""

        def __init__(self, tensorList, dataGenDict):
            self.tensorList = tensorList
            self.dataGenDict = dataGenDict

    # Default high value for random numbers
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
        DType.FP8E4M3: 448,
        DType.FP8E5M2: 57344,
    }
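    # Note: (1 << 128) - (1 << (127 - 23)) == (2 - 2**-23) * 2**127, the largest
    # finite FP32 value; the FP16 and BF16 entries follow the same pattern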

    # Default lowest normal values for random numbers
    TVG_FLOAT_LOW_VALUE = {
        DType.FP32: np.exp2(-126),
        DType.FP16: np.exp2(-14),
        DType.BF16: np.exp2(-126),
        DType.FP8E4M3: np.exp2(-9),
        DType.FP8E5M2: np.exp2(-16),
    }
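    # e.g. np.exp2(-126) is the smallest positive normal FP32/BF16 value and
    # np.exp2(-14) the smallest positive normal FP16 value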

    @staticmethod
    def _get_data_range(rng, dtype, highValueLookup, lowValueLookup=None):
        # Return a tuple of (low,high) data range values for the given data
        # type using a combination of per operator table limits, data limits
        # and user supplied ranges for FP numbers
        if dtype in highValueLookup:
            type_range = rng.dTypeRange(dtype, high_inclusive=True)
            high_val = highValueLookup[dtype]
            if lowValueLookup is not None and dtype in lowValueLookup:
                low_val = lowValueLookup[dtype]
            else:
                low_val = -high_val
            # Set the values to something that won't produce infinity, whilst
            # still respecting the default ranges when they are narrower than
            # the low/high values
            data_range = (
                max(low_val, type_range[0]),
                min(high_val, type_range[1]),
            )
            if data_range[0] > data_range[1]:
                # An invalid low-to-high data range was created due to user
                # constraints; revert to using the internal ranges as they are
                # known to work
                logger.info(
                    f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
                )
                data_range = (low_val, high_val)
            return data_range
        return None
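    # Example use (see the tvg* helpers below): a per-operator table such as
    # TVG_FLOAT_HIGH_VALUE_ADDSUB is passed as highValueLookup so that the
    # random data stays within a range that cannot overflow for that operator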

    @staticmethod
    def tvgLazyGenDefault(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
        if "p_count" in argsDict:
            # Override for operators like CONCAT
            pCount = argsDict["p_count"]
            cCount = argsDict["c_count"]
        assert pCount + cCount == len(
            shapeList
        ), "Placeholders & Constant tensors must match shapes list"

        tens_ser_list = []

        if (
            error_name is not None
            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
            or "data_gen" not in testGen.TOSA_OP_LIST[opName]
        ):
            # Fall back to internal data gen when dealing with unsupported types or ops
            data_range = argsDict["data_range"] if "data_range" in argsDict else None
            for idx, info in enumerate(zip(shapeList, dtypeList)):
                roundMode = False
                shape, dtype = info
                if "data_range_list" in argsDict:
                    data_range = argsDict["data_range_list"][idx]["range"]
                    roundMode = (
                        "round" in argsDict["data_range_list"][idx]
                        and argsDict["data_range_list"][idx]["round"] is True
                    )
                if data_range is not None and dtype not in (
                    DType.FP16,
                    DType.FP32,
                    DType.BF16,
                    DType.FP8E4M3,
                    DType.FP8E5M2,
                ):
                    # Change from inclusive to exclusive range
                    data_range = (data_range[0], data_range[1] + 1)

                # Ignore lazy data gen option and create data array using any range limits
                if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                    if dtype == DType.SHAPE:
                        arr = np.int64(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT8:
                        arr = np.int8(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT16:
                        arr = np.int16(argsDict["fixed_data"][idx])
                    elif dtype == DType.INT32:
                        arr = np.int32(argsDict["fixed_data"][idx])
                    else:
                        assert False, "Unsupported fixed_data type"
                else:
                    arr = rng.randTensor(shape, dtype, data_range)
                if roundMode:
                    arr = np.round(arr)
                if idx < pCount:
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
        dg_type = argsDict["dg_type"]
        tens_data = {
            "version": "0.1",
            "tensors": {},
        }
        dg_tens_meta = tens_data["tensors"]
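        # Each tensor gets an entry in tens_data["tensors"], keyed by its
        # serializer tensor name, holding the generator name, data type, shape,
        # input position/type and generator-specific info (pseudo_random_info,
        # dot_product_info, fixed_data_info or full_range_info) - see the
        # per-generator branches below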
        for idx, shape in enumerate(shapeList):

            tens_meta = {}
            if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                tens_meta["generator"] = gtu.DataGenType(
                    gtu.DataGenType.FIXED_DATA
                ).name
            else:
                tens_meta["generator"] = gtu.DataGenType(dg_type).name

            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()

            if idx < pCount:
                tens_meta["input_type"] = "VARIABLE"
            else:
                tens_meta["input_type"] = "CONSTANT"

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                if (
                    tens_meta["generator"]
                    == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name
                ):
                    info["data"] = [int(i) for i in argsDict["fixed_data"][idx]]
                    tens_meta["fixed_data_info"] = info
                else:
                    info["rng_seed"] = rng.seed

                    data_range = None
                    if "data_range_list" in argsDict:
                        data_range = argsDict["data_range_list"][idx]["range"]
                        if "round" in argsDict["data_range_list"][idx]:
                            info["round"] = argsDict["data_range_list"][idx]["round"]
                    elif "data_range" in argsDict:
                        data_range = argsDict["data_range"]

                    if data_range is None:
                        data_range = rng.dTypeRange(dtypeList[idx], high_inclusive=True)
                    info["range"] = [str(v) for v in data_range]
                    tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = int(argsDict["ks"])
                if "acc_type" in argsDict:
                    # Convert type number into JSON name
                    info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
                        "json"
                    ]
                if "kernel" in argsDict:
                    info["kernel"] = [int(k) for k in argsDict["kernel"]]
                if "axis" in argsDict:
                    info["axis"] = int(argsDict["axis"])
                tens_meta["dot_product_info"] = info
            elif dg_type == gtu.DataGenType.FULL_RANGE:
                info = {}
                info["start_val"] = int(
                    rng.randInt(0, gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["fullset"])
                )
                tens_meta["full_range_info"] = info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"

            # Using the finished generate config meta data - generate the data if
            # needed and assign a tensor name from the serializer

            # Need to generate data when not lazy or for the bias tensor as we need
            # to work out if the bias data is non-zero for compliance
            if not testGen.args.lazy_data_gen or (
                idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
            ):
                # Give this tensor a temporary name until we get one from the serializer
                temp_name = f"placeholder_{idx}"
                dg_tens_meta[temp_name] = tens_meta
                # Create data now using the temporary name to access meta details
                data = testGen.dgl.get_tensor_data(temp_name, tens_data)
                if tens_meta["data_type"] == "SHAPE":
                    # Tensor type SHAPE and Numpy file type must be the same
                    data = np.int64(data)
                # Remove the item as we will give it the correct name later
                del dg_tens_meta[temp_name]

                if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
                    # The KS value used by compliance verification is altered when the
                    # bias data is non-zero
                    if max(abs(data)) > 0.0:
                        argsDict["ksb"] = argsDict["ks"] + 1

            if testGen.args.lazy_data_gen:
                data = None

            if tens_meta["input_type"] == "VARIABLE":
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
            else:
                tens = testGen.ser.addConst(shape, dtypeList[idx], data)

            tens_ser_list.append(tens)
            # Add the meta data to the list using the serializer tensor name
            dg_tens_meta[tens.name] = tens_meta

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)

    @staticmethod
    def tvgNegate(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Integer test
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.NEGATE must have 1 placeholder, 0 consts"
            # Must create tensors with values within accumulator (int32) negatable
            # range
            max_val = (1 << 31) - 1
            min_val = -max_val
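            # (the most negative int32, -(1 << 31), is excluded because its
            # negation does not fit in int32)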
            arr = np.int32(
                rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
            )
            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )

    # Set the ADD/SUB data range to half the largest value to avoid infinities
    TVG_FLOAT_HIGH_VALUE_ADDSUB = {
        DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
        DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
        DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
    }
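    # With both operands limited to half the largest finite value, a + b (or
    # a - b) can never exceed the largest finite value, so no infinities occur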

    @staticmethod
    def tvgAddSub(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
            # Make sure the integer operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            tens_ser_list = []
            add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
            data_range = testGen.args.tensor_shape_range
            a_arr = rng.randTensor(shapeList[0], dtypeList[0], data_range)
            b_arr = rng.randTensor(shapeList[1], dtypeList[1], data_range)
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            data_range = TosaTensorValuesGen._get_data_range(
                rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
            )
            if data_range:
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )

    @staticmethod
    def tvgCondIfWhileLoop(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        if dtypeList[0] in (
            DType.INT32,
            DType.INT16,
            DType.INT8,
        ):
            # Limit input tensors with cond_if_binary or while_loop to stop
            # saturation of add/sub ops with int32 and keep all logical shift
            # values between 0 and 31 for int16 or int8
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            pRemain = pCount
            tens_ser_list = []
            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[0] == DType.INT32:
                    arr = rng.randTensor(shapeList[idx], DType.INT16)
                else:
                    arr = np.int32(rng.integers(low=0, high=32, size=shapeList[idx]))
                if pRemain > 0:
                    tens_ser_list.append(
                        testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
                    )
                    pRemain -= 1
                else:
                    tens_ser_list.append(
                        testGen.ser.addConst(shape, dtypeList[idx], arr)
                    )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
            )

    @staticmethod
    def tvgArithmeticRightShift(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        op = testGen.TOSA_OP_LIST[opName]
        pCount, cCount = op["operands"]
        # Force value of operand[1] to be within [0, num_bits]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

        tens_ser_list = []
        for idx, shape in enumerate(shapeList[:]):
            if idx == 1:
                if dtypeList[idx] == DType.INT8:
                    arr = np.int32(rng.integers(low=0, high=8, size=shape))
                elif dtypeList[idx] == DType.INT16:
                    arr = np.int32(rng.integers(low=0, high=16, size=shape))
                elif dtypeList[idx] == DType.INT32:
                    arr = np.int32(rng.integers(low=0, high=32, size=shape))
                elif error_name == ErrorIf.WrongInputType:
                    arr = np.int32(rng.integers(low=0, high=8, size=shape))
                else:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
            else:
                arr = rng.randTensor(shape, dtypeList[idx])
            tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

    @staticmethod
    def tvgReshape(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        dtypeList[1] = DType.SHAPE
        shapeList[1] = [len(argsDict["new_shape"])]
        # Create a new list for the pre-generated data in argsDict["fixed_data"]
        argsDict["fixed_data"] = [None, argsDict["new_shape"]]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgRescale(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        scale32 = argsDict["scale"]
        multiplier_arr = argsDict["multiplier"]
        shift_arr = argsDict["shift"]

        if scale32:
            dtypeList[1] = DType.INT32
        else:
            dtypeList[1] = DType.INT16
        shapeList[1] = [len(multiplier_arr)]
        dtypeList[2] = DType.INT8
        shapeList[2] = [len(shift_arr)]
        # Create a new list for the pre-generated data in argsDict["fixed_data"]
        argsDict["fixed_data"] = [None, multiplier_arr, shift_arr]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgPad(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
        # argsDict["pad"] is 2D array, need to flatten it to get list of values
        pad_values = argsDict["pad"].flatten()
        dtypeList[1] = DType.SHAPE
        shapeList[1] = [len(pad_values)]
        # Create a new list for the pre-generated data in argsDict["fixed_data"]
        argsDict["fixed_data"] = [None, pad_values]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgSlice(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
        dtypeList[1] = DType.SHAPE
        shapeList[1] = [len(argsDict["start"])]
        dtypeList[2] = DType.SHAPE
        shapeList[2] = [len(argsDict["size"])]
        # Create a new list for the pre-generated data in argsDict["fixed_data"]
        argsDict["fixed_data"] = [None, argsDict["start"], argsDict["size"]]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgTile(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
        dtypeList[1] = DType.SHAPE
        shapeList[1] = [len(argsDict["multiples"])]
        argsDict["fixed_data"] = [None, argsDict["multiples"]]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgSelect(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        # Set datatype of condition tensor to boolean
        dtypeList[0] = DType.BOOL

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgIntDiv(
        testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        if error_name is None:
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            tens_ser_list = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
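            #    (case 2 would overflow, as the result +(1<<31) is not
            #    representable in int32)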
1165 while True:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001166 dividend_arr = rng.randTensor(shapeList[0], dtypeList[0])
1167 divisor_arr = rng.randTensor(shapeList[1], dtypeList[1])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001168
1169 if (divisor_arr == 0).any():
1170 continue
1171
1172 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
1173 continue
1174
1175 break
1176
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001177 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001178 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1179 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001180 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001181 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1182 )
1183
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001184 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001185 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001186 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001187 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001188 )
1189
Jeremy Johnson30476252023-11-20 16:15:30 +00001190 # Set the MUL data range to the square root of the largest value
1191 # to avoid infinities
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001192 TVG_FLOAT_HIGH_VALUE_MUL = {
1193 DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1194 DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1195 DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1196 }
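# Using the square root means the product of two in-range operands stays at or below TVG_FLOAT_HIGH_VALUE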
1197
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001198 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001199 def tvgMul(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001200 if error_name is not None or dtypeList[0] in (
1201 DType.FP16,
1202 DType.BF16,
1203 DType.FP32,
1204 ):
1205 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001206 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001207 rng, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001208 )
1209 if data_range:
1210 argsDict["data_range"] = data_range
1211
Jeremy Johnson0a042992024-02-28 13:20:05 +00001212 if dtypeList[0] != DType.SHAPE:
1213 # Need to supply shift tensor for MUL (not needed for MUL_SHAPE)
1214 dtypeList[2] = DType.INT8
1215 shapeList[2] = [1]
1216 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1217 argsDict["fixed_data"] = [None, None, [argsDict["shift"]]]
1218
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001219 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001220 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001221 )
1222 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001223 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001224 pCount, cCount = op["operands"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001225
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001226 tens_ser_list = []
1227
1228 # Make sure the multiply result is within int32 range
Won Jeon74342e52024-01-09 00:34:40 +00001229 if dtypeList[0] == DType.SHAPE:
1230 shift = 0
1231 else:
1232 shift = argsDict["shift"]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001233 if dtypeList[0] == DType.INT8:
1234 num_bits = 8
1235 elif dtypeList[0] == DType.INT16:
1236 num_bits = 16
Won Jeon74342e52024-01-09 00:34:40 +00001237 elif dtypeList[0] in (DType.INT32, DType.SHAPE):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001238 num_bits = 32
1239 elif error_name == ErrorIf.WrongInputType:
1240 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001241 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001242 raise Exception(
1243 f"OpMul: invalid input dtype {gtu.DTYPE_ATTRIBUTES[dtypeList[0]]['str']}"
1244 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001245
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001246 for idx, shape in enumerate(shapeList[:]):
Won Jeon74342e52024-01-09 00:34:40 +00001247 if dtypeList[idx] == DType.SHAPE:
1248 low = testGen.args.tensor_shape_range[0]
1249 high = testGen.args.tensor_shape_range[1]
1250 else:
1251 low = -(2 ** (num_bits - 1))
1252 high = (2 ** (num_bits - 1)) - 1
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001253
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001254 a_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[0]))
1255 b_arr = np.int32(rng.integers(low=low, high=high, size=shapeList[1]))
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001256
1257 i = 0
1258 while True:
1259
1260 a_arr_64 = a_arr.astype(np.int64)
1261 b_arr_64 = b_arr.astype(np.int64)
1262
1263 if shift > 0:
1264 rounding = 1 << (shift - 1)
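# e.g. a=7, b=5, shift=2: rounding=2, so (35 + 2) >> 2 = 9 (8.75 rounded to nearest)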
1265 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001266 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001267 result_arr = a_arr_64 * b_arr_64
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001268
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001269 if (result_arr > -(2**31)).all() and (
1270 result_arr <= ((2**31) - 1)
1271 ).all():
1272 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001273
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001274 i = i + 1
1275 a_arr = a_arr // 2
1276 b_arr = b_arr // 2
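# Halving both operands shrinks the product by roughly 4x each retry until it fits in int32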
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001277
Won Jeon74342e52024-01-09 00:34:40 +00001278 if dtypeList[0] == DType.SHAPE:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001279 # MUL_SHAPE with 2 inputs
Won Jeon74342e52024-01-09 00:34:40 +00001280 tens_ser_list.append(
1281 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
1282 )
1283 tens_ser_list.append(
1284 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
1285 )
1286 else:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001287 # MUL with 3 inputs (3rd is shift)
Won Jeon74342e52024-01-09 00:34:40 +00001288 tens_ser_list.append(
1289 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1290 )
1291 tens_ser_list.append(
1292 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1293 )
Jeremy Johnson0a042992024-02-28 13:20:05 +00001294 tens_ser_list.append(
1295 testGen.ser.addPlaceholder([1], DType.INT8, np.int8([shift]))
1296 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001297
1298 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001299
1300 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001301 def tvgConcat(
1302 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1303 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001304 count = len(shapeList) - testGen.args.num_const_inputs_concat
1305 if count < 1:
1306 count = 1
1307 if testGen.args.num_const_inputs_concat == 0:
1308 count = len(shapeList)
1309
Won Jeon74342e52024-01-09 00:34:40 +00001310 op = testGen.TOSA_OP_LIST[opName]
1311 if op["op"] == Op.CONCAT_SHAPE:
1312 # Set the axis to 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001313 shapeList = TosaTensorGen.tgConcatConstInput(rng, shapeList, 0, error_name)
Won Jeon74342e52024-01-09 00:34:40 +00001314 else:
1315 shapeList = TosaTensorGen.tgConcatConstInput(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001316 rng, shapeList, argsDict["axis"], error_name
Won Jeon74342e52024-01-09 00:34:40 +00001317 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001318
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001319 # Override default pCount/cCount for operator
1320 argsDict["p_count"] = count
1321 argsDict["c_count"] = len(shapeList) - count
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001322
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001323 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001324 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001325 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001326
1327 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001328 def tvgLogicalShift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001329 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001330 ):
1331 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001332 pCount, cCount = op["operands"]
1333 assert (
1334 pCount == 2 and cCount == 0
1335 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001336 values_arr = rng.randTensor(shapeList[0], dtypeList[0])
1337 shift_arr = np.int32(rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001338 tens_ser_list = []
1339 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001340 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1341 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001342 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001343 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1344 )
1345
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001346 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001347
1348 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001349 def tvgEqual(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnsona0150012023-11-15 15:52:06 +00001350 if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
1351 # Integer
1352 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001353 pCount, cCount = op["operands"]
1354 assert (
1355 pCount == 2 and cCount == 0
1356 ), "Op.EQUAL must have 2 placeholders, 0 consts"
Jeremy Johnsona0150012023-11-15 15:52:06 +00001357
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001358 a_arr = rng.randTensor(shapeList[0], dtypeList[0])
1359 b_arr = rng.randTensor(shapeList[1], dtypeList[1])
Jeremy Johnsona0150012023-11-15 15:52:06 +00001360
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001361 # Random data makes it very unlikely that any values will
1362 # match (be equal), therefore force in twice as many matching
1363 # values as the tensor rank
1364 for num in range(0, len(shapeList[0]) * 2):
1365 a_index = []
1366 b_index = []
1367 # Choose an index in each axis for the whole shape
1368 for axis in range(0, len(shapeList[0])):
1369 # Index can be up to the largest dimension in both shapes
1370 index = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001371 rng.integers(0, max(shapeList[0][axis], shapeList[1][axis]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001372 )
1373 # Reduce the index down to a shape's dim for broadcasting
1374 a_index.append(min(shapeList[0][axis] - 1, index))
1375 b_index.append(min(shapeList[1][axis] - 1, index))
1376
1377 a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
1378
Jeremy Johnsona0150012023-11-15 15:52:06 +00001379 tens_ser_list = []
1380 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001381 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1382 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001383 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001384 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1385 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001386 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001387 else:
Jeremy Johnsona0150012023-11-15 15:52:06 +00001388 # ERROR_IF or floating point test
1389 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001390 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001391 )
1392
1393 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001394 def tvgReduceSum(
1395 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1396 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001397 dtype = dtypeList[0]
1398 if dtype == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001399 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001400 pCount, cCount = op["operands"]
1401 assert (
1402 pCount == 1 and cCount == 0
1403 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
1404 # Limit values so that the sum cannot exceed the range of an int32 during
1405 # summation of any axis
1406 range_val = int((1 << 31) / max(shapeList[0]))
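# e.g. with a largest dimension of 8, range_val is 2**28, so summing 8 such values cannot overflow int32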
1407 values_arr = np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001408 rng.integers(low=-range_val, high=range_val, size=shapeList[0])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001409 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001410 tens_ser_list = []
1411 tens_ser_list.append(
Jeremy Johnson30476252023-11-20 16:15:30 +00001412 testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001413 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001414 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001415 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001416 # ERROR_IF or dot product floating point test
Jeremy Johnson30476252023-11-20 16:15:30 +00001417 if (
1418 error_name is None
1419 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
1420 ):
1421 # Limit ranges for (non error & non compliance) tests by using
1422 # values that can be summed on any axis to not hit infinity
1423 highval_lookup = {
1424 dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
1425 / max(shapeList[0])
1426 }
1427 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001428 rng, dtype, highval_lookup
Jeremy Johnson30476252023-11-20 16:15:30 +00001429 )
1430 assert data_range is not None
1431 argsDict["data_range"] = data_range
1432
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001433 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001434 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001435 )
1436
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001437 @staticmethod
1438 def tvgReduceProduct(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001439 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001440 ):
1441 dtype = dtypeList[0]
1442 if error_name is None:
1443 # Limit ranges for (non error) tests by using
1444 # values that can be multiplied on any axis to not hit infinity
1445 highval_lookup = {
1446 dtype: math.pow(
1447 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
1448 1 / max(shapeList[0]),
1449 )
1450 }
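# Taking the Nth root (N = largest dimension) means a product of up to N such values cannot exceed TVG_FLOAT_HIGH_VALUE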
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001451 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001452 assert data_range is not None
1453 argsDict["data_range"] = data_range
1454
1455 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001456 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001457 )
1458
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001459 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001460 def tvgResize(
1461 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1462 ):
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001463 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001464 rng,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001465 dtypeList[0],
1466 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1467 )
1468 if data_range:
1469 argsDict["data_range"] = data_range
1470 # Needed for compliance
1471 argsDict["max_abs_value"] = data_range[1]
1472
Tai Lyc5c2a7e2024-02-22 23:26:28 +00001473 scale_values = argsDict["scale"]
1474 offset_values = argsDict["offset"]
1475 border_values = argsDict["border"]
1476 dtypeList[1] = DType.SHAPE
1477 dtypeList[2] = DType.SHAPE
1478 dtypeList[3] = DType.SHAPE
1479 shapeList[1] = [len(scale_values)]
1480 shapeList[2] = [len(offset_values)]
1481 shapeList[3] = [len(border_values)]
1482 argsDict["fixed_data"] = [None, scale_values, offset_values, border_values]
1483
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001484 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001485 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001486 )
1487
Jeremy Johnson30476252023-11-20 16:15:30 +00001488 # Set the POW exponent high data range
1489 TVG_FLOAT_HIGH_VALUE_POW_EXP = {
1490 DType.FP32: 10.0,
1491 DType.FP16: 10.0,
1492 DType.BF16: 10.0,
1493 }
1494 # POW highest base value (within a safe margin of error) that can be
1495 # raised to a +ve exponent without the result becoming Infinity
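# i.e. raising such a base to any +ve exponent up to the exponent high value cannot exceed TVG_FLOAT_HIGH_VALUE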
1496 TVG_FLOAT_HIGH_VALUE_POW_BASE = {
1497 DType.FP32: math.floor(
1498 math.pow(
1499 TVG_FLOAT_HIGH_VALUE[DType.FP32],
1500 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1501 )
1502 ),
1503 DType.FP16: math.floor(
1504 math.pow(
1505 TVG_FLOAT_HIGH_VALUE[DType.FP16],
1506 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1507 )
1508 ),
1509 DType.BF16: math.floor(
1510 math.pow(
1511 TVG_FLOAT_HIGH_VALUE[DType.BF16],
1512 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1513 )
1514 ),
1515 }
1516 # POW lowest base value (within a safe margin of error) that can be
1517 # raised to a -ve exponent without the result becoming Infinity
1518 TVG_FLOAT_LOW_VALUE_POW_BASE = {
1519 DType.FP32: math.ceil(
1520 math.pow(
1521 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
1522 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1523 )
1524 * 1000
1525 )
1526 / 1000,
1527 DType.FP16: math.ceil(
1528 math.pow(
1529 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
1530 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1531 )
1532 * 1000
1533 )
1534 / 1000,
1535 DType.BF16: math.ceil(
1536 math.pow(
1537 1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
1538 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1539 )
1540 * 1000
1541 )
1542 / 1000,
1543 }
1544
1545 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001546 def tvgPow(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001547 if error_name is not None:
1548 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001549 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001550 )
1551 dtype = dtypeList[0]
1552 # Different ranges for POW
1553 test_set = argsDict["s"]
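# Test set 0: positive base, fractional exponent
# Test set 1: positive base, integer exponent
# Test set 2: negative base, integer exponent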
1554 if test_set == 0:
1555 # Positive base with fractional exponent
1556 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001557 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001558 dtype,
1559 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1560 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1561 )
1562 exp_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001563 rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
Jeremy Johnson30476252023-11-20 16:15:30 +00001564 )
1565 exp_round = False
1566 else:
1567 # Integer exponent
1568 exp_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001569 rng, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
Jeremy Johnson30476252023-11-20 16:15:30 +00001570 )
1571 exp_round = True
1572 if test_set == 1:
1573 # Positive base
1574 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001575 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001576 dtype,
1577 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1578 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1579 )
1580 else:
1581 assert test_set == 2
1582 # Negative base
1583 # Supply new look up tables with negative values
1584 base_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001585 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001586 dtype,
1587 {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
1588 {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
1589 )
1590
1591 data_range_list = (
1592 {
1593 "range": base_range,
1594 },
1595 {
1596 "range": exp_range,
1597 "round": exp_round,
1598 },
1599 )
1600 argsDict["data_range_list"] = data_range_list
1601 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001602 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001603 )
1604
1605 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001606 def tvgLogRsqrt(
1607 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1608 ):
Jeremy Johnson30476252023-11-20 16:15:30 +00001609 # LOG & RSQRT data range from lowest expressible positive number to
1610 # largest to avoid NaNs
1611 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001612 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001613 dtypeList[0],
1614 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1615 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1616 )
1617 if data_range:
1618 argsDict["data_range"] = data_range
1619
1620 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001621 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001622 )
1623
1624 # Set the EXP data range to the logs of the smallest and largest values
1625 # to avoid infinities or a zero result
1626 TVG_FLOAT_HIGH_VALUE_EXP = {
1627 DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1628 DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1629 DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1630 }
1631 TVG_FLOAT_LOW_VALUE_EXP = {
1632 DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
1633 DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
1634 DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
1635 }
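# exp() of these log bounds gives back TVG_FLOAT_LOW/HIGH_VALUE, so EXP results stay finite and non-zero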
1636
1637 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001638 def tvgExp(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001639 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001640 rng,
Jeremy Johnson30476252023-11-20 16:15:30 +00001641 dtypeList[0],
1642 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1643 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1644 )
1645 if data_range:
1646 argsDict["data_range"] = data_range
1647
1648 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001649 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001650 )
1651
1652 @staticmethod
1653 def tvgFullyConnected(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001654 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson30476252023-11-20 16:15:30 +00001655 ):
1656 dtype = dtypeList[0]
1657 if (
1658 error_name is None
1659 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
Jeremy Johnson718f3472023-11-30 14:18:19 +00001660 and dtype in (DType.BF16,)
Jeremy Johnson30476252023-11-20 16:15:30 +00001661 ):
Jeremy Johnson718f3472023-11-30 14:18:19 +00001662 # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
Jeremy Johnson30476252023-11-20 16:15:30 +00001663 # Limit ranges for (non error & non compliance) FP tests by using
1664 # values that can be multiplied on any axis to not hit infinity/NaN
1665 IC = shapeList[0][1]
1666 highval_lookup = {
1667 dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
1668 }
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001669 data_range = TosaTensorValuesGen._get_data_range(rng, dtype, highval_lookup)
Jeremy Johnson30476252023-11-20 16:15:30 +00001670 assert data_range is not None
1671 argsDict["data_range"] = data_range
1672
1673 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001674 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson30476252023-11-20 16:15:30 +00001675 )
1676
Jeremy Johnson708da822023-11-15 16:25:45 +00001677 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001678 def tvgCast(testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson708da822023-11-15 16:25:45 +00001679 in_dtype = dtypeList[0]
1680 out_dtype = argsDict["out_type"]
1681 # Create look up to limit input tensor to output type maximums to avoid
1682 # FP infinities and saturation of integers
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001683 out_range = rng.dTypeRange(out_dtype, high_inclusive=True)
Jeremy Johnson708da822023-11-15 16:25:45 +00001684 highval_lookup = {in_dtype: out_range[1]}
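# e.g. for a CAST to INT8 the high value is limited to the INT8 maximum of 127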
1685 data_range = TosaTensorValuesGen._get_data_range(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001686 rng,
Jeremy Johnson708da822023-11-15 16:25:45 +00001687 in_dtype,
1688 highval_lookup,
1689 )
1690
1691 assert data_range is not None
1692 argsDict["data_range"] = data_range
1693
1694 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001695 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson708da822023-11-15 16:25:45 +00001696 )
1697
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001698 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001699 def tvgGather(
1700 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1701 ):
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001702 K = shapeList[0][1]
1703
1704 # Fix the type of the indices tensor
1705 dtypeList[1] = DType.INT32
1706
1707 dtype = dtypeList[0]
1708 if not gtu.dtypeIsSupportedByCompliance(dtype):
1709 # Test unsupported by data generator
1710 op = testGen.TOSA_OP_LIST[opName]
1711 pCount, cCount = op["operands"]
1712 assert (
1713 pCount == 2 and cCount == 0
1714 ), "Op.GATHER must have 2 placeholders, 0 consts"
1715
1716 tens_ser_list = []
1717 for idx, shape in enumerate(shapeList):
1718 dtype = dtypeList[idx]
1719 if idx != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001720 arr = rng.randTensor(shape, dtype)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001721 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1722 else:
1723 # Limit data range of indices tensor to up to K (exclusive)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001724 arr = rng.randTensor(shape, dtype, (0, K))
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001725 # To match old functionality - create indices as CONST
1726 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1727
1728 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1729
1730 else:
1731 # ERROR_IF or floating point test
1732 # Use inclusive values up to index K - 1 for the indices tensor
1733 data_range_list = (
1734 {"range": None},
1735 {"range": (0, K - 1)},
1736 )
1737 argsDict["data_range_list"] = data_range_list
1738
1739 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001740 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001741 )
1742
1743 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001744 def tvgScatter(
1745 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name=None
1746 ):
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001747 K = shapeList[0][1]
1748 W = shapeList[2][1]
1749
1750 # Work out an indices tensor here with data that doesn't exceed the
1751 # dimension K of the values_in tensor and does NOT repeat the same
1752 # output location within K, as required by the spec:
1753 # "It is not permitted to repeat the same output index within a single
1754 # SCATTER operation and so each output index occurs at most once."
1755 assert K >= W, "Op.SCATTER W must be smaller than or equal to K"
1756
1757 # Fix the type of the indices tensor
1758 dtypeList[1] = DType.INT32
1759
1760 dtype = dtypeList[0]
1761 if not gtu.dtypeIsSupportedByCompliance(dtype):
1762 # Test unsupported by data generator
1763 op = testGen.TOSA_OP_LIST[opName]
1764 pCount, cCount = op["operands"]
1765 assert (
1766 pCount == 3 and cCount == 0
1767 ), "Op.SCATTER must have 3 placeholders, 0 consts"
1768
1769 tens_ser_list = []
1770 for idx, shape in enumerate(shapeList):
1771 dtype = dtypeList[idx]
1772 if idx != 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001773 arr = rng.randTensor(shape, dtype)
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001774 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1775 else:
1776 # Create the indices array
1777 assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
1778 arr = []
1779 for n in range(shape[0]):
1780 # Get a shuffled list of output indices (0 to K-1) and
1781 # limit length to W
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001782 arr.append(rng.permutation(K)[:W])
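# Each row holds W unique indices drawn from [0, K), so no output index is repeated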
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001783 indices_arr = np.array(arr, dtype=np.int32) # (N, W)
1784 # To match old functionality - create indices as CONST
1785 tens_ser_list.append(
1786 testGen.ser.addConst(shape, dtype, indices_arr)
1787 )
1788
1789 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1790
1791 else:
1792 # ERROR_IF or floating point test
1793 # Use inclusive values up to index K - 1 for the indices tensor
1794 data_range_list = (
1795 {"range": None},
1796 {"range": (0, K - 1)},
1797 {"range": None},
1798 )
1799 argsDict["data_range_list"] = data_range_list
1800
1801 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001802 testGen, rng, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001803 )
1804
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001805
1806class TosaArgGen:
1807 """Argument generators create exhaustive or random lists of attributes for
1808 operators that take attributes or other parameters.
1809
1810 The return value is a list of (descriptive_name, [arglist]) tuples where
1811 the descriptive_name is appended to the test name and the arglist is expanded
1812 as arguments to the operator build function.
1813 """
1814
1815 def __init__(self):
1816 pass
1817
1818 @staticmethod
evacha019c96eef2024-02-07 11:21:55 +00001819 def _add_data_generators(testGen, opName, shapeList, dtype, arg_list, error_name):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001820 """Add extra tests for each type of data generator for this op."""
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001821 if (
1822 error_name is None
1823 and "data_gen" in testGen.TOSA_OP_LIST[opName]
1824 and gtu.dtypeIsSupportedByCompliance(dtype)
1825 ):
Tai Ly60dc48c2024-03-08 22:19:41 +00001826 if gtu.dtypeIsFloat(dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001827 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
1828 else:
1829 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
1830 else:
1831 # Error test or No data generator types listed - assume random
1832 dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
1833
1834 # Expand arg list with other data generator types
1835 new_arg_list = []
1836 for dg_type in dataGenTypesList:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001837 for arg_str, args_dict in arg_list:
evacha019c96eef2024-02-07 11:21:55 +00001838
1839 if dg_type == gtu.DataGenType.FULL_RANGE:
1840 tensor_size = gtu.product(shapeList[0])
1841 if tensor_size >= gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]:
1842 # Large enough tensor data size for full range, add a single test
1843 num_test_sets = 0
1844 else:
1845 # Not enough data size for full range of values, revert to random numbers
1846 dg_type = gtu.DataGenType.PSEUDO_RANDOM
1847
Jeremy Johnson1271c442023-09-05 11:39:26 +01001848 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
Jeremy Johnson30476252023-11-20 16:15:30 +00001849 if error_name is None:
1850 num_test_sets = (
1851 args_dict["num_test_sets"]
1852 if "num_test_sets" in args_dict
1853 else 0
1854 )
1855 else:
evacha019c96eef2024-02-07 11:21:55 +00001856 # Add single test for pseudo random
Jeremy Johnson30476252023-11-20 16:15:30 +00001857 num_test_sets = 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01001858
1859 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
1860 # Extra tests for each dot product test set
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001861 dot_products = args_dict["dot_products"]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001862 if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001863 shape_info = (
1864 " ({})".format(testGen.shapeStr(args_dict["shape"]))
1865 if "shape" in args_dict
1866 else ""
1867 )
Jeremy Johnsonaf090182024-02-13 18:25:39 +00001868 logger.info(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001869 f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
Jeremy Johnson1271c442023-09-05 11:39:26 +01001870 )
1871 continue
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001872 # KS and acc_type is required by all dot product generators
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001873 assert "ks" in args_dict
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001874 assert "acc_type" in args_dict
Jeremy Johnson1271c442023-09-05 11:39:26 +01001875
Jeremy Johnson30476252023-11-20 16:15:30 +00001876 num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS
1877
1878 if num_test_sets > 0:
1879 for s in range(0, num_test_sets):
evacha019c96eef2024-02-07 11:21:55 +00001880 set_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
1881 set_args_dict = args_dict.copy()
1882 set_args_dict["s"] = s
1883 set_args_dict["dg_type"] = dg_type
1884 new_arg_list.append((set_arg_str, set_args_dict))
Jeremy Johnson30476252023-11-20 16:15:30 +00001885 else:
1886 # Default is a single test
evacha019c96eef2024-02-07 11:21:55 +00001887 new_args_dict = args_dict.copy()
1888 new_args_dict["dg_type"] = dg_type
1889 new_arg_list.append((arg_str, new_args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001890
1891 return new_arg_list
1892
1893 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001894 def agNone(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001895 """A trivial argument generator for operators that don't take any
1896 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001897 arg_list = TosaArgGen._add_data_generators(
1898 testGen,
1899 opName,
evacha019c96eef2024-02-07 11:21:55 +00001900 shapeList,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001901 dtype,
1902 [("", {})],
1903 error_name,
1904 )
1905 # Return list of tuples: (arg_str, args_dict)
1906 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001907
1908 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001909 def agPow(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001910 """Pow operator needs different test sets to cover random numbers
1911 without creating NaNs or Infs"""
1912 arg_list = TosaArgGen._add_data_generators(
1913 testGen,
1914 opName,
evacha019c96eef2024-02-07 11:21:55 +00001915 shapeList,
Jeremy Johnson30476252023-11-20 16:15:30 +00001916 dtype,
1917 [("", {"num_test_sets": 3})],
1918 error_name,
1919 )
1920 # Return list of tuples: (arg_str, args_dict)
1921 return arg_list
1922
1923 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001924 def agAxis(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001925 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001926 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001927 shape = shapeList[0]
1928
1929 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001930 # Set an axis that is too small (negative)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001931 axes = [rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001932 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001933 # Set an axis that is too large (greater than the rank)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001934 axes = [rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001935 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001936 # Create tests for each dimension
1937 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001938
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001939 opid = testGen.TOSA_OP_LIST[opName]["op"]
1940
1941 for a in axes:
1942 args_dict = {"axis": int(a)}
1943 if opid == Op.REDUCE_SUM:
Jeremy Johnsone52c0a32024-03-11 09:58:24 +00001944 output_shape = shape.copy()
1945 if error_name is None:
1946 # It only matters that we calculate the dot_products correctly
1947 # for non error_if tests as they should never be run
1948 output_shape[a] = 1
1949 args_dict["dot_products"] = gtu.product(output_shape)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001950 args_dict["shape"] = shape
1951 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1952 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1953
1954 arg_list.append(("axis{}".format(a), args_dict))
1955
1956 arg_list = TosaArgGen._add_data_generators(
1957 testGen,
1958 opName,
evacha019c96eef2024-02-07 11:21:55 +00001959 shapeList,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001960 dtype,
1961 arg_list,
1962 error_name,
1963 )
1964 # Return list of tuples: (arg_str, args_dict)
1965 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001966
1967 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001968 def _calculate_sparsity(num_tests, sparsity_factor):
1969 sparsity = num_tests // sparsity_factor + 1
1970 # If there are only a small number of tests, just select them all
1971 if sparsity < 13:
1972 sparsity = 1
1973 # To get a variety of parameter combinations sparsity should not be a
1974 # multiple of 2, 3 or 5
1975 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1976 sparsity += 1
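# e.g. 1000 tests with a sparsity_factor of 120: 1000 // 120 + 1 = 9, bumped to 11 (not a multiple of 2, 3 or 5)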
1977 return sparsity
1978
1979 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001980 def agConv(testGen, rng, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson0c716862023-04-13 17:18:19 +01001981 # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001982 arg_list = []
1983
Jeremy Johnson0c716862023-04-13 17:18:19 +01001984 if testGen.args.level8k and error_name is not None:
1985 # Don't produce negative large tests
1986 return arg_list
1987
1988 # Shape: Batches, (Depth), Height, Width, Channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001989 ifm_shape = shapeList[0]
Jeremy Johnson0c716862023-04-13 17:18:19 +01001990 # Shape: (OFM channels), (KD), KH, KW, IFM channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001991 filter_shape = shapeList[1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001992
Jeremy Johnson1271c442023-09-05 11:39:26 +01001993 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01001994
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001995 # Op type checks
Jeremy Johnson0c716862023-04-13 17:18:19 +01001996 conv3d = opName.startswith("conv3d")
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001997 depthwise = opName.startswith("depthwise")
1998
1999 # Check the rank
Jeremy Johnson0c716862023-04-13 17:18:19 +01002000 rank = 5 if conv3d else 4
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002001 if error_name != ErrorIf.WrongRank:
2002 assert len(ifm_shape) == rank
2003 assert len(filter_shape) == rank
2004
Jeremy Johnson0c716862023-04-13 17:18:19 +01002005 # kernel rank omits channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002006 k_rank = rank - 2
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002007 k_pos = 0 if depthwise else 1
Jeremy Johnson0c716862023-04-13 17:18:19 +01002008 k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002009 # compliance size - KS
2010 k_size = gtu.product(k_shape)
2011 if not depthwise:
2012 k_size *= ifm_shape[-1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002013
Jeremy Johnson0c716862023-04-13 17:18:19 +01002014 if not testGen.args.level8k:
2015 # Generate comprehensive argument lists
2016 # - except for named errors, which use specific invalid value(s)
2017 if error_name == ErrorIf.PadSmallerZero:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002018 p_vals = [rng.choice(range(-5, 0))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002019 else:
2020 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
2021 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
2022 if error_name == ErrorIf.StrideSmallerOne:
2023 # Can't use stride=0, as it is used as a divisor when deriving the output shape
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002024 s_vals = [rng.choice(range(-5, 0))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002025 else:
2026 # Stride must be greater than 1 to force non-integer error
2027 startStride = (
2028 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002029 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002030 s_vals = [
2031 x for x in range(startStride, testGen.args.max_conv_stride + 1)
2032 ]
2033 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
2034 if error_name == ErrorIf.DilationSmallerOne:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002035 d_vals = [rng.choice(range(-5, 1))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002036 else:
2037 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
2038 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002039
Jeremy Johnson0c716862023-04-13 17:18:19 +01002040 if not error_name and testGen.args.oversize:
2041 # add some oversize argument values
2042 if max(ifm_shape) < 64:
2043 bigPadding = 9
2044 paddings.update(
2045 {
2046 x
2047 for x in itertools.product(
2048 *([[0, bigPadding]] * (k_rank * 2))
2049 )
2050 }
2051 )
2052 bigStride = 8
2053 strides.update(
2054 {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
2055 )
2056 bigDilation = 7
2057 dilations.update(
2058 {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
2059 )
2060 max_dim_size = None
2061
2062 # There are too many parameter combinations, so generate them sparsely,
2063 # very sparse for negative tests
2064 sparsity_factor = 2 if error_name else 120
2065 sparsity = TosaArgGen._calculate_sparsity(
2066 len(paddings) * len(strides) * len(dilations), sparsity_factor
2067 )
2068 else:
2069 # Only test 8k levels boundaries
2070 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2071 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2072 bigPadding = bigKernel
2073
2074 dilation_shape = [1] * k_rank
2075 pad_shape = [0] * k_rank * 2
2076 if conv3d:
2077 # Small stride apart from for big kernel (see below) to keep
2078 # tensor size/calculation small
2079 stride_shape = [1] * k_rank
2080 for idx in range(k_rank):
2081 pad_offset = idx * 2
2082 if k_shape[idx] == bigKernel:
2083 # Padding shape needs to account for tensor shape
2084 pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
2085 pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
2086 # Big stride to reduce output size
2087 stride_shape[idx] = bigKernel
2088 else:
2089 # Account for kernel size
2090 pad_shape[pad_offset] = k_shape[idx] - 1
2091 else:
2092 # Always have a large stride with extra padding and dilation to keep
2093 # tensor calculation reasonable
2094 stride_shape = [bigKernel] * k_rank
2095 for idx in range(k_rank):
2096 # Dilation shape must account for kernel size
2097 dilation_shape[idx] = bigKernel // k_shape[idx]
2098 # Padding shape needs to accommodate tensor/kernel & dilation
2099 pad_offset = idx * 2
2100 pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
2101 pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
2102
2103 strides = {tuple(stride_shape)}
2104 dilations = {tuple(dilation_shape)}
2105 paddings = {tuple(pad_shape)}
2106 # Create a limit for the output dimensions size
2107 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2108
2109 # Currently allow all combinations that are reasonable size
2110 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002111
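# Only every sparsity-th combination of (stride, padding, dilation) generated below becomes a test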
2112 n = 0
2113 for s in sorted(list(strides)):
2114 for p in sorted(list(paddings)):
2115 for d in sorted(list(dilations)):
2116 if (
2117 n % sparsity == 0
Jeremy Johnson93d43902022-09-27 12:26:14 +01002118 # the padded shape must exceed the dilation * kernel to get a positive
2119 # sized output shape
Jeremy Johnson0c716862023-04-13 17:18:19 +01002120 and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
2121 and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
Jeremy Johnson93d43902022-09-27 12:26:14 +01002122 and (
2123 k_rank < 3
Jeremy Johnson0c716862023-04-13 17:18:19 +01002124 or (
2125 (ifm_shape[3] - 1 + p[4] + p[5])
2126 > d[2] * (k_shape[2] - 1)
2127 )
Jeremy Johnson93d43902022-09-27 12:26:14 +01002128 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002129 ):
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002130 remainders = []
Jeremy Johnson0c716862023-04-13 17:18:19 +01002131 outputs = []
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002132 for index in range(k_rank):
2133 pad_offset = index * 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002134 partial = (
2135 ifm_shape[index + 1]
2136 - 1
2137 + p[pad_offset]
2138 + p[pad_offset + 1]
2139 - (k_shape[index] - 1) * d[index]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002140 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002141 remainders.append(partial % s[index])
2142 outputs.append((partial // s[index]) + 1)
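# i.e. output dim = (I - 1 + pad_before + pad_after - (K - 1) * dilation) // stride + 1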
2143
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002144 if (
2145 # the parameters must produce integer exact output
2146 error_name != ErrorIf.ConvOutputShapeNonInteger
2147 and max(remainders) == 0
2148 ) or (
2149 error_name == ErrorIf.ConvOutputShapeNonInteger
2150 and max(remainders) > 0
2151 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002152 if (
2153 max_dim_size is not None
2154 and max(outputs) >= max_dim_size
2155 ):
2156 # Test will consume too much memory - skip it
2157 continue
2158
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002159 # Compliance - number of dot product calculations
2160 if depthwise:
Jeremy Johnson4f931302024-01-04 17:05:24 +00002161 # N*OH*OW*C*M
2162 dots = gtu.product(
2163 (ifm_shape[0], *outputs, *filter_shape[2:])
2164 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002165 else:
Jeremy Johnson4f931302024-01-04 17:05:24 +00002166 # N*OH*OW*OC or N*OD*OH*OW*OC
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002167 dots = gtu.product(
2168 (ifm_shape[0], *outputs, filter_shape[0])
2169 )
2170 args_dict = {
2171 "acc_type": accum_dtype,
2172 "stride": s,
2173 "pad": p,
2174 "dilation": d,
2175 "kernel": k_shape,
2176 "ks": k_size,
2177 "dot_products": dots,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002178 "shape": ifm_shape,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002179 }
2180
Jeremy Johnson0c716862023-04-13 17:18:19 +01002181 # Support for larger values than 9 needs different delimiter
2182 delim = "" if max(s + p + d) <= 9 else "x"
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002183 arg_list.append(
2184 (
James Ward8b390432022-08-12 20:48:56 +01002185 "acc{}_st{}_pad{}_dilat{}".format(
2186 testGen.typeStr(accum_dtype),
Jeremy Johnson0c716862023-04-13 17:18:19 +01002187 delim.join([str(x) for x in s]),
2188 delim.join([str(x) for x in p]),
2189 delim.join([str(x) for x in d]),
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002190 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002191 args_dict,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002192 )
2193 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002194 n += 1
2195
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002196 arg_list = TosaArgGen._add_data_generators(
2197 testGen,
2198 opName,
evacha019c96eef2024-02-07 11:21:55 +00002199 shapeList,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002200 dtypes[0],
2201 arg_list,
2202 error_name,
2203 )
2204 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002205 return arg_list
2206
2207 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002208 def agFullyConnected(testGen, rng, opName, shapeList, dtypes, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002209
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002210 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002211 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01002212
2213 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002214 accum_dtype = gtu.get_wrong_output_type(opName, rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01002215 elif error_name == ErrorIf.WrongInputType:
2216 # Pick some potentially correct output dtype if input type is incorrect
2217 accum_dtype = DType.INT32
2218 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +01002219 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01002220
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002221 # Set up compliance info
2222 args_dict = {
2223 "acc_type": accum_dtype,
2224 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
2225 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
2226 "shape": shapeList[0],
2227 }
2228
2229 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
2230
2231 arg_list = TosaArgGen._add_data_generators(
2232 testGen,
2233 opName,
evacha019c96eef2024-02-07 11:21:55 +00002234 shapeList,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002235 input_dtype,
2236 arg_list,
2237 error_name,
2238 )
2239 # Return list of tuples: (arg_str, args_dict)
2240 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002241
2242 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002243 def agMatMul(testGen, rng, opName, shapeList, dtype, error_name=None):
James Ward8b390432022-08-12 20:48:56 +01002244 # Get valid accumulate type(s)
2245 if dtype == DType.INT8:
2246 accum_dtypes = [DType.INT32]
2247 elif dtype == DType.INT16:
2248 accum_dtypes = [DType.INT48]
2249 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002250 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002251 elif dtype == DType.BF16:
2252 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002253 elif dtype == DType.FP32:
2254 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002255 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2256 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002257 elif error_name is None:
2258 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
2259
2260 if error_name == ErrorIf.WrongOutputType:
2261 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002262 accum_dtypes = [gtu.get_wrong_output_type(opName, rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01002263 elif error_name == ErrorIf.WrongInputType:
2264 # Pick some potentially correct output dtype if input type is incorrect
2265 accum_dtypes = [DType.INT32]
2266
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002267 # Set up compliance info
2268 args_dict = {
2269 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
2270 # Set dot_products = N*H*W
2271 "dot_products": gtu.product(
2272 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
2273 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002274 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002275 }
2276
2277 # Create arg tuple of string and dict
2278 arg_list = []
2279 for a in accum_dtypes:
2280 d = args_dict.copy()
2281 d["acc_type"] = a
2282 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01002283
2284 arg_list = TosaArgGen._add_data_generators(
2285 testGen,
2286 opName,
evacha019c96eef2024-02-07 11:21:55 +00002287 shapeList,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002288 dtype,
2289 arg_list,
2290 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002291 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002292 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01002293 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002294
2295 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002296 def agTransposeConv2D(testGen, rng, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002297 arg_list = []
2298
Jeremy Johnson0c716862023-04-13 17:18:19 +01002299 if testGen.args.level8k and error_name is not None:
2300 # Don't produce negative large tests
2301 return arg_list
2302
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002303 ifm_shape = shapeList[0]
2304 filter_shape = shapeList[1]
2305
Jeremy Johnson1271c442023-09-05 11:39:26 +01002306 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01002307
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002308 # Must be rank 4
2309 if error_name != ErrorIf.WrongRank:
2310 assert len(ifm_shape) == 4
2311 assert len(filter_shape) == 4
2312
Jeremy Johnson0c716862023-04-13 17:18:19 +01002313 k_shape = tuple(filter_shape[1:3])
Jeremy Johnson95a67102024-01-10 14:16:39 +00002314 # compliance size - KS
2315 k_size = gtu.product((*k_shape, ifm_shape[3]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002316
Jeremy Johnson0c716862023-04-13 17:18:19 +01002317 if not testGen.args.level8k:
2318 # Generate comprehensive argument lists
2319 # - except for named errors, which use specific invalid value(s)
2320 smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
2321 if error_name == ErrorIf.PadLargerEqualKernel:
2322 max_filter_size = -max(k_shape[0], k_shape[1])
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002323 p_vals = [rng.choice(range(max_filter_size - 10, max_filter_size))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002324 else:
2325 p_vals = [
2326 x
2327 for x in range(
2328 smallest_padding_size, testGen.args.max_conv_padding + 1
2329 )
2330 ]
2331 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2332 if error_name == ErrorIf.StrideSmallerOne:
2333 # Can't use stride=0, as it is used as a divisor when deriving the output shape
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002334 s_vals = [rng.choice(range(-5, 0))]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002335 else:
2336 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
2337 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002338
Jeremy Johnson0c716862023-04-13 17:18:19 +01002339 if not error_name and testGen.args.oversize:
2340 # add some oversize argument values
2341 if max(ifm_shape) < 64:
2342 bigPadding = 9
2343 paddings.update(
2344 {
2345 x
2346 for x in itertools.product(
2347 *([[smallest_padding_size, bigPadding]] * 4)
2348 )
2349 }
2350 )
2351 bigStride = 8
2352 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
2353
2354 # There are too many parameter combinations, so generate them sparsely,
2355 # very sparse for negative tests
2356 sparsity_factor = 2 if error_name else 10
2357 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
2358 # If there are only a small number of tests, just select them all
2359 if sparsity < 13:
2360 sparsity = 1
2361 # To get a variety of parameter combinations sparsity should not be a
2362 # multiple of 2, 3 or 5
2363 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
2364 sparsity += 1
2365 else:
2366 # Only test 8k levels boundaries
2367 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2368 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2369 bigPadding = bigKernel
2370
2371 pad_shape = [0] * (len(k_shape) * 2)
2372 stride_shape = [1] * len(k_shape)
2373 # The point at which the input dimension, combined with the stride,
2374 # will create large output sizes
2375 LARGE_SIZE = 2
2376 for idx in range(len(k_shape)):
2377 pad_offset = idx * 2
2378 if k_shape[idx] == bigKernel:
2379 # Set large stride
2380 stride_shape[idx] = bigKernel
2381 # Use negative output padding to reduce shape size
2382 pad_shape[pad_offset] = -(bigPadding - 1)
2383 if ifm_shape[idx + 1] > LARGE_SIZE:
2384 pad_shape[pad_offset + 1] = -(bigPadding - 1)
2385 else:
2386 # The other dimension should be the bigKernel
2387 alt_idx = 1 - idx
2388 if (
2389 k_shape[alt_idx] == bigKernel
2390 and ifm_shape[alt_idx + 1] < LARGE_SIZE
2391 ):
2392 # As the input is small, the large stride won't
2393 # affect the output so we can add some padding
2394 pad_shape[pad_offset + 1] = bigPadding
2395
2396 strides = {tuple(stride_shape)}
2397 paddings = {tuple(pad_shape)}
2398
2399 # Currently allow all combinations that are reasonable size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002400 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002401
2402 n = 0
2403 for s in sorted(list(strides)):
2404 for p in sorted(list(paddings)):
TatWai Chong24594f52022-06-08 00:48:04 -07002405 if n % sparsity == 0:
2406 # Determine the output shape
Jeremy Johnson0c716862023-04-13 17:18:19 +01002407 oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
2408 ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
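# Transpose conv output size: O = (I - 1) * stride + out_pad_before + out_pad_after + kernel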
TatWai Chong24594f52022-06-08 00:48:04 -07002409 os = [ifm_shape[0], oh, ow, filter_shape[0]]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002410
Jeremy Johnson95a67102024-01-10 14:16:39 +00002411 # N*OH*OW*OC
2412 dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
2413 args_dict = {
2414 "acc_type": accum_dtype,
2415 "stride": s,
2416 "pad": p,
2417 "kernel": k_shape,
2418 "ks": k_size,
2419 "dot_products": dots,
2420 "shape": ifm_shape,
2421 "out_shape": os,
2422 }
2423
Jeremy Johnson0c716862023-04-13 17:18:19 +01002424 # Support for larger values than 9 needs different delimiter
2425 delim = "" if max(s + p) <= 9 else "x"
TatWai Chong24594f52022-06-08 00:48:04 -07002426 arg_list.append(
2427 (
James Ward8b390432022-08-12 20:48:56 +01002428 "acc{}_st{}_pad{}_os{}".format(
2429 testGen.typeStr(accum_dtype),
Jeremy Johnson0c716862023-04-13 17:18:19 +01002430 delim.join([str(x) for x in s]),
2431 delim.join([str(x) for x in p]),
TatWai Chong24594f52022-06-08 00:48:04 -07002432 "x".join([str(x) for x in os]),
2433 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00002434 args_dict,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002435 )
TatWai Chong24594f52022-06-08 00:48:04 -07002436 )
2437 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002438
Jeremy Johnson95a67102024-01-10 14:16:39 +00002439 arg_list = TosaArgGen._add_data_generators(
2440 testGen,
2441 opName,
evacha019c96eef2024-02-07 11:21:55 +00002442 shapeList,
Jeremy Johnson95a67102024-01-10 14:16:39 +00002443 dtypes[0],
2444 arg_list,
2445 error_name,
2446 )
2447 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002448 return arg_list
2449
2450 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002451 def agPad(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002452 rank = len(shapeList[0])
2453
2454 # Exhaustively test combinations of padding on each side of each dimension
2455 # - the range of padding values is defined by pad_min and pad_max
2456 # - for padding >9, the name format needs to be more distinctive
2457 pad_min, pad_max = 0, 1
2458 pad_values = [x for x in range(pad_min, pad_max + 1)]
2459 if error_name == ErrorIf.PadSmallerZero:
2460 pad_values = [x for x in range(-2, 0)]
2461 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
2462 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
2463
2464 if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002465 pad_const_int = rng.randNumberDType(dtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002466 pad_const_fp = 0
Tai Ly60dc48c2024-03-08 22:19:41 +00002467 elif gtu.dtypeIsFloat(dtype):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002468 pad_const_int = 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002469 pad_const_fp = rng.randNumberDType(dtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002470 else:
2471 return []
2472
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002473 list_shape_pad_values = list(shape_pad_values)
2474 # If there are a lot of combinations (e.g. rank 6 or greater) use sparsity
2475 if len(list_shape_pad_values) > 1024:
2476 sparsity_factor = 2 if error_name else 120
2477 sparsity = TosaArgGen._calculate_sparsity(
2478 len(list_shape_pad_values), sparsity_factor
2479 )
2480 else:
2481 sparsity = 1
2482
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002483 # Build arg list
2484 arg_list = []
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002485 for n, paddings in enumerate(list_shape_pad_values):
James Ward8b390432022-08-12 20:48:56 +01002486 paddings = list(paddings)
2487 args_valid = True
2488
2489 if error_name == ErrorIf.PadSmallerZero:
2490 # Prevent negative output shapes while still testing for negative padding
2491 for i in range(rank):
2492 dim_after_padding = (
2493 paddings[i][0] + paddings[i][1] + shapeList[0][i]
2494 )
2495 if dim_after_padding < 1:
2496 paddings[i] = (0, 0)
2497 if all([p > -1 for p in paddings[i]]):
2498 args_valid = False
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002499 if args_valid and n % sparsity == 0:
James Ward8b390432022-08-12 20:48:56 +01002500 name = "pad"
2501 for r in range(rank):
2502 before, after = paddings[r]
2503 name = f"{name}{before}{after}"
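# For example (illustrative only): paddings of [(0, 1), (1, 0)] on a rank 2
# shape produce the test name "pad0110".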
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002504 args_dict = {
2505 "pad": np.array(paddings),
2506 "pad_const_int": pad_const_int,
2507 "pad_const_fp": pad_const_fp,
2508 }
2509 arg_list.append((name, args_dict))
James Ward8b390432022-08-12 20:48:56 +01002510
2511 if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
Jeremy Johnsonaf090182024-02-13 18:25:39 +00002512 logger.info(f"No ErrorIf test created for input shape: {shapeList[0]}")
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002513
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002514 arg_list = TosaArgGen._add_data_generators(
2515 testGen,
2516 opName,
evacha019c96eef2024-02-07 11:21:55 +00002517 shapeList,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002518 dtype,
2519 arg_list,
2520 error_name,
2521 )
2522
2523 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002524 return arg_list
2525
2526 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002527 def agPooling(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002528 arg_list = []
2529
2530 shape = shapeList[0]
2531 if error_name != ErrorIf.WrongRank:
2532 assert len(shape) == 4
2533
Jeremy Johnson0c716862023-04-13 17:18:19 +01002534 test_level8k = testGen.args.level8k and error_name is None
2535
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002536 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002537 startKernel = 2
2538 startPad = 0
2539 if not test_level8k:
2540 # Generate comprehensive argument lists
2541 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2542 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2543 # Stride must be greater than 1 to force non-integer error
2544 s_vals = [
2545 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2546 ]
2547 strides = {x for x in itertools.product(*([s_vals] * 2))}
2548 k_vals = [
2549 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2550 ]
2551 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2552 max_dim_size = None
2553 else:
2554 # Only test 8k levels
2555 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2556 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2557 strides = {(1, bigStride), (bigStride, 4)}
2558 kernels = {(1, bigKernel), (bigKernel, 3)}
2559 paddings = set()
2560 for s in sorted(list(strides)):
2561 for k in sorted(list(kernels)):
2562 padding = []
2563 for idx in range(len(k)):
2564 total_padding = s[idx] - shape[idx + 1] + k[idx]
2565 while total_padding < 0:
2566 # Must meet: shape + padding > kernel
2567 total_padding += s[idx]
2568 if total_padding < k[idx]:
2569 padding.extend([0, total_padding])
2570 else:
2571 # Note this may produce padding >= k[idx] which is not
2572 # allowed - but will be ignored in the creation loop below
2573 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2574 paddings.add(tuple(padding))
2575 # Create a limit for the output dimensions size
2576 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002577
James Ward8b390432022-08-12 20:48:56 +01002578 if opName == "max_pool2d":
2579 accum_dtypes = [None] # max_pool has no accumulate dtype
2580 elif dtype == DType.INT8 or dtype == DType.INT16:
2581 accum_dtypes = [DType.INT32]
2582 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002583 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002584 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002585 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002586 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2587 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002588 elif error_name is None:
2589 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2590 else:
2591 # Set to something for the ErrorIf case which has
2592 # incorrect input data-type
2593 accum_dtypes = [DType.INT32]
2594
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00002595 if error_name == ErrorIf.WrongAccumulatorType:
2596 accum_dtypes = list(gtu.usableDTypes(excludes=accum_dtypes))
2597
Jeremy Johnson0c716862023-04-13 17:18:19 +01002598 if not test_level8k:
2599 if testGen.args.oversize:
2600 # add some oversize argument values
2601 bigStride = 7
2602 bigKernel = 9
2603 strides.update(
2604 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002605 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002606 kernels.update(
2607 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2608 )
2609 if max(shape) < 64:
2610 # padding must be less than the kernel size
2611 bigPadding = bigKernel - 1
2612 paddings.update(
2613 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2614 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002615
Jeremy Johnson0c716862023-04-13 17:18:19 +01002616 # There are too many parameter combinations, so generate them sparsely,
2617 # very sparse for negative tests
2618 sparsity_factor = 2 if error_name else 500
2619 sparsity = (
2620 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2621 )
2622 else:
2623 # We have already limited test output combinations for 8k tests
2624 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002625
James Ward8b390432022-08-12 20:48:56 +01002626 arg_str = (
2627 "acc{}_st{}_kern{}_pad{}"
2628 if accum_dtypes[0] is not None
2629 else "st{}_kern{}_pad{}"
2630 )
2631
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002632 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01002633 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002634 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01002635
2636 # Values larger than 9 need a different delimiter in the test name
2637 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01002638 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01002639 delim.join([str(x) for x in stride]),
2640 delim.join([str(x) for x in kern]),
2641 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01002642 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002643 args_dict = {
2644 "stride": stride,
2645 "pad": pad,
2646 "kernel": kern,
2647 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002648 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002649 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
2650 }
James Ward8b390432022-08-12 20:48:56 +01002651
2652 if accum is not None:
2653 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002654 args_dict["acc_type"] = accum
2655 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002656
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002657 n = 0
James Ward8b390432022-08-12 20:48:56 +01002658 for a in accum_dtypes:
2659 for s in sorted(list(strides)):
2660 for p in sorted(list(paddings)):
2661 for k in sorted(list(kernels)):
2662 if error_name in [
2663 ErrorIf.StrideSmallerOne,
2664 ErrorIf.KernelSmallerOne,
2665 ErrorIf.PadSmallerZero,
2666 ErrorIf.PadLargerEqualKernel,
2667 ]:
2668 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002669 rng, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002670 )
James Ward8b390432022-08-12 20:48:56 +01002671 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002672 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002673 get_arg_list_element(a, sNew, pNew, kNew, shape=shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002674 )
James Ward8b390432022-08-12 20:48:56 +01002675 elif (
2676 n % sparsity == 0
2677 # padding must not exceed the kernel size
2678 and p[0] < k[0]
2679 and p[1] < k[0]
2680 and p[2] < k[1]
2681 and p[3] < k[1]
2682 # the padded shape must exceed the kernel size
2683 and (shape[1] + p[0] + p[1]) > k[0]
2684 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002685 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002686 partial_h = shape[1] + p[0] + p[1] - k[0]
2687 partial_w = shape[2] + p[2] + p[3] - k[1]
2688 remainder_h = partial_h % s[0]
2689 remainder_w = partial_w % s[1]
2690 output_h = partial_h // s[0] + 1
2691 output_w = partial_w // s[1] + 1
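# Worked sketch with hypothetical values: shape[1] = 32, pad (1, 1),
# kernel 3 and stride 2 give partial_h = 32 + 1 + 1 - 3 = 31 and
# remainder_h = 31 % 2 = 1, so this combination is only kept when testing
# ErrorIf.PoolingOutputShapeNonInteger; output_h would be 31 // 2 + 1 = 16.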
Jeremy Johnsonaf090182024-02-13 18:25:39 +00002692 logger.debug(
2693 f"agPooling: {shape} remainder=({remainder_h}, {remainder_w}) output=({output_h}, {output_w})"
2694 )
James Ward8b390432022-08-12 20:48:56 +01002695 if (
2696 # the parameters must produce integer exact output
2697 error_name != ErrorIf.PoolingOutputShapeNonInteger
2698 and remainder_h == 0
2699 and remainder_w == 0
2700 ) or (
2701 error_name == ErrorIf.PoolingOutputShapeNonInteger
2702 and (remainder_h != 0 or remainder_w != 0)
2703 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002704 if (
2705 max_dim_size is not None
2706 and max(output_h, output_w) > max_dim_size
2707 ):
2708 # Test will consume too much memory - skip it
2709 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002710 # Dot products = N*OH*OW*C
2711 dp = gtu.product(
2712 (shape[0], output_h, output_w, shape[3])
2713 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002714 arg_list.append(
2715 get_arg_list_element(a, s, p, k, dp, shape)
2716 )
James Ward8b390432022-08-12 20:48:56 +01002717 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002718
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002719 # Now add data generator types
2720 arg_list = TosaArgGen._add_data_generators(
2721 testGen,
2722 opName,
evacha019c96eef2024-02-07 11:21:55 +00002723 shapeList,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002724 dtype,
2725 arg_list,
2726 error_name,
2727 )
2728
2729 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002730 return arg_list
2731
2732 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002733 def agCast(testGen, rng, opName, shapeList, inDtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002734 arg_list = []
2735
2736 # Enumerate the output types here
2737 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002738 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(inDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002739 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00002740 dtypeList = [
2741 DType.BOOL,
2742 DType.INT16,
2743 DType.INT32,
2744 DType.FP16,
2745 DType.BF16,
2746 DType.FP32,
2747 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002748 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00002749 dtypeList = [
2750 DType.BOOL,
2751 DType.INT8,
2752 DType.INT32,
2753 DType.FP16,
2754 DType.BF16,
2755 DType.FP32,
2756 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002757 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00002758 dtypeList = [
2759 DType.BOOL,
2760 DType.INT8,
2761 DType.INT16,
2762 DType.FP16,
2763 DType.BF16,
2764 DType.FP32,
2765 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002766 elif inDtype == DType.BOOL:
2767 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01002768 elif inDtype == DType.FP16:
Won Jeon2c34b462024-02-06 18:37:00 +00002769 dtypeList = [
2770 DType.INT8,
2771 DType.INT16,
2772 DType.INT32,
2773 DType.FP32,
2774 DType.FP8E4M3,
2775 DType.FP8E5M2,
2776 ]
James Ward24dbc422022-10-19 12:20:31 +01002777 elif inDtype == DType.BF16:
Won Jeon2c34b462024-02-06 18:37:00 +00002778 dtypeList = [
2779 DType.INT8,
2780 DType.INT16,
2781 DType.INT32,
2782 DType.FP32,
2783 DType.FP8E4M3,
2784 DType.FP8E5M2,
2785 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002786 elif inDtype == DType.FP32:
Won Jeon2c34b462024-02-06 18:37:00 +00002787 dtypeList = [
2788 DType.INT8,
2789 DType.INT16,
2790 DType.INT32,
2791 DType.FP16,
2792 DType.BF16,
2793 DType.FP8E4M3,
2794 DType.FP8E5M2,
2795 ]
2796 elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]:
2797 dtypeList = [DType.FP16, DType.BF16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002798 elif error_name == ErrorIf.WrongInputType:
2799 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002800 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002801 else:
2802 raise Exception("Unexpected input dtype: {}".format(inDtype))
2803
2804 for dtype in dtypeList:
Jeremy Johnson708da822023-11-15 16:25:45 +00002805 arg_list.append(
2806 ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
2807 )
2808
2809 # Now add data generator types
2810 arg_list = TosaArgGen._add_data_generators(
2811 testGen,
2812 opName,
evacha019c96eef2024-02-07 11:21:55 +00002813 shapeList,
Jeremy Johnson708da822023-11-15 16:25:45 +00002814 dtype,
2815 arg_list,
2816 error_name,
2817 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002818
2819 return arg_list
2820
2821 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002822 def agRescale(testGen, rng, opName, shapeList, inDtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002823 arg_list = []
2824
2825 # Enumerate the output types here
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002826 for outDtype in [
2827 DType.UINT8,
2828 DType.INT8,
2829 DType.INT16,
2830 DType.INT32,
2831 DType.UINT16,
2832 ]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002833 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002834 outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002835 and error_name == ErrorIf.OutputZeroPointNotZero
2836 ):
2837 continue
2838 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002839 outDtype != DType.UINT16
2840 and error_name == ErrorIf.U16OutputZeroPointNotValid
2841 ) or (
2842 inDtype != DType.UINT16
2843 and error_name == ErrorIf.U16InputZeroPointNotValid
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002844 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002845 # ErrorIfs only valid with UINT16
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002846 continue
2847 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002848 inDtype == DType.UINT8
2849 and outDtype not in [DType.INT8, DType.INT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002850 and error_name != ErrorIf.WrongOutputType
2851 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002852 # The only output dtypes for UINT8 are INT8/INT16, skip all others
2853 continue
2854 if (
2855 inDtype not in [DType.INT8, DType.INT16]
2856 and outDtype == DType.UINT8
2857 and error_name != ErrorIf.WrongOutputType
2858 ):
2859 # The only input dtypes that can produce UINT8 output are INT8/INT16, skip all others
2860 continue
2861 if (
2862 inDtype == DType.UINT16
2863 and outDtype != DType.INT16
2864 and error_name != ErrorIf.WrongOutputType
2865 ):
2866 # The only output dtype for UINT16 is INT16, skip all others
2867 continue
2868 if (
2869 inDtype != DType.INT16
2870 and outDtype == DType.UINT16
2871 and error_name != ErrorIf.WrongOutputType
2872 ):
2873 # The only input dtype that can produce UINT16 output is INT16, skip all others
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002874 continue
2875 if (
2876 error_name == ErrorIf.WrongOutputType
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002877 and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002878 ):
2879 continue
2880
2881 for scale32 in [False, True]:
2882 if error_name == ErrorIf.ScaleTrue and not scale32:
2883 continue
2884 elif error_name == ErrorIf.ScaleNotTrue and scale32:
2885 continue
2886 for double_round in [False, True]:
2887 if error_name == ErrorIf.ScaleNotTrue and not double_round:
2888 continue
2889 for per_channel in [False, True]:
2890
2891 if (
2892 inDtype == DType.INT48
2893 and scale32
2894 and error_name != ErrorIf.ScaleTrue
2895 ):
2896 # Illegal condition. Must be scale32=False
2897 continue
2898 if (
2899 double_round
2900 and not scale32
2901 and error_name != ErrorIf.ScaleNotTrue
2902 ):
2903 # Illegal condition. ERROR_IF(!scale32 && double_round)
2904 continue
2905
Tai Ly6e1e2bc2024-03-01 20:59:32 +00002906 if per_channel:
2907 nc = shapeList[0][-1]
2908 else:
2909 nc = 1
2910
2911 in_type_width = gtu.dtypeWidth(inDtype)
2912 out_type_width = gtu.dtypeWidth(outDtype)
2913
2914 # Calculate scale based on:
2915 # scale = a * (2^output_width) / (2^input_width)
2916
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002917 a = np.float32(rng.random(size=[nc]))
Tai Ly6e1e2bc2024-03-01 20:59:32 +00002918 scale_arr = a * np.float32(
2919 (1 << out_type_width) / (1 << in_type_width)
2920 )
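# A worked sketch with hypothetical widths (INT8 input, INT32 output):
#   (1 << 32) / (1 << 8) = 2 ** 24 = 16777216
# so with a = 0.5 the scale would be 8388608.0, which is then clipped into
# range if necessary and decomposed into a multiplier/shift pair below.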
2921
2922 if scale32:
2923 # Cap the scaling at 2^31 - 1 for scale32
2924 scale_arr = np.clip(
2925 scale_arr, 1.0 / (1 << 31), (1 << 31) - 1
2926 )
2927 else:
2928 # Cap the scaling at 2^15 - 1 for scale16
2929 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
2930
Jeremy Johnsonaf090182024-02-13 18:25:39 +00002931 logger.debug(
2932 f"agRescale: {out_type_width} {in_type_width} -> {scale_arr}"
2933 )
Tai Ly6e1e2bc2024-03-01 20:59:32 +00002934
2935 multiplier_arr = np.int32(np.zeros(shape=[nc]))
2936 shift_arr = np.int32(np.zeros(shape=[nc]))
2937 for i in range(nc):
2938 (
2939 multiplier_arr[i],
2940 shift_arr[i],
2941 ) = TosaQuantGen.computeMultiplierAndShift(
2942 scale_arr[i], scale32
2943 )
2944
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002945 arg_list.append(
2946 (
2947 "out{}_sc{}_dr{}_pc{}".format(
Jeremy Johnson3b0544c2022-10-18 16:32:19 +01002948 testGen.typeStr(outDtype),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002949 int(scale32),
2950 int(double_round),
2951 int(per_channel),
2952 ),
Jeremy Johnson587cc842024-02-08 11:45:44 +00002953 {
2954 "output_dtype": outDtype,
2955 "scale": scale32,
2956 "double_round": double_round,
2957 "per_channel": per_channel,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00002958 "multiplier": multiplier_arr,
2959 "shift": shift_arr,
Jeremy Johnson587cc842024-02-08 11:45:44 +00002960 },
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002961 )
2962 )
2963
Jeremy Johnson587cc842024-02-08 11:45:44 +00002964 arg_list = TosaArgGen._add_data_generators(
2965 testGen,
2966 opName,
evacha019c96eef2024-02-07 11:21:55 +00002967 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00002968 inDtype,
2969 arg_list,
2970 error_name,
2971 )
2972 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002973 return arg_list
2974
2975 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002976 def agMul(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002977 arg_list = []
2978
2979 if dtype is DType.INT32:
2980 for p in range(testGen.args.num_rand_permutations):
2981
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002982 shift = rng.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002983 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002984 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002985 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002986
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002987 arg_list = TosaArgGen._add_data_generators(
2988 testGen,
2989 opName,
evacha019c96eef2024-02-07 11:21:55 +00002990 shapeList,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002991 dtype,
2992 arg_list,
2993 error_name,
2994 )
2995 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002996 return arg_list
2997
2998 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002999 def agArithmeticRightShift(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003000 arg_list = []
3001
Jeremy Johnson587cc842024-02-08 11:45:44 +00003002 for round in (True, False):
3003 args_dict = {
3004 "round": round,
3005 }
3006 arg_list.append((f"round{round}", args_dict))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003007
Jeremy Johnson587cc842024-02-08 11:45:44 +00003008 arg_list = TosaArgGen._add_data_generators(
3009 testGen,
3010 opName,
evacha019c96eef2024-02-07 11:21:55 +00003011 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003012 dtype,
3013 arg_list,
3014 error_name,
3015 )
3016 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003017 return arg_list
3018
Luke Hutton57287132023-02-06 14:54:18 +00003019 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003020 def agFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Luke Hutton57287132023-02-06 14:54:18 +00003021 arg_list = []
3022
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003023 shape = shapeList[0]
3024 dot_products = gtu.product(shape)
3025 ks = 2 * shape[1] * shape[2] # 2*H*W
3026 for inverse in (True, False):
3027 args_dict = {
3028 "dot_products": dot_products,
3029 "shape": shape,
3030 "ks": ks,
3031 "acc_type": dtype,
3032 "inverse": inverse,
3033 }
3034 arg_list.append((f"inverse{inverse}", args_dict))
Luke Hutton57287132023-02-06 14:54:18 +00003035
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003036 arg_list = TosaArgGen._add_data_generators(
3037 testGen,
3038 opName,
evacha019c96eef2024-02-07 11:21:55 +00003039 shapeList,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003040 dtype,
3041 arg_list,
3042 error_name,
3043 )
3044 # Return list of tuples: (arg_str, args_dict)
Luke Hutton57287132023-02-06 14:54:18 +00003045 return arg_list
3046
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003047 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003048 def agRFFT2d(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003049 arg_list = []
3050
3051 shape = shapeList[0]
3052 dot_products = gtu.product(shape)
3053 ks = shape[1] * shape[2] # H*W
3054 args_dict = {
3055 "dot_products": dot_products,
3056 "shape": shape,
3057 "ks": ks,
3058 "acc_type": dtype,
3059 }
3060 arg_list.append(("", args_dict))
3061
3062 arg_list = TosaArgGen._add_data_generators(
3063 testGen,
3064 opName,
evacha019c96eef2024-02-07 11:21:55 +00003065 shapeList,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00003066 dtype,
3067 arg_list,
3068 error_name,
3069 )
3070 # Return list of tuples: (arg_str, args_dict)
3071 return arg_list
3072
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003073 # Helper function for reshape. Gets some factors of a larger number.
3074 @staticmethod
3075 def getFactors(val, start=1):
3076 factors = []
3077
3078 for i in range(start, int(np.sqrt(val)) + 1):
3079 if (val % i) == 0:
3080 factors.append(i)
3081
3082 return factors
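# Illustrative usage (doctest style, hypothetical values):
#   >>> TosaArgGen.getFactors(36)
#   [1, 2, 3, 4, 6]
# Only divisors up to sqrt(val) are collected, so callers re-run it on the
# remaining element count when building up a new shape.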
3083
3084 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003085 def agReshape(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003086 arg_list = []
3087
3088 origShape = shapeList[0]
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003089 totalElements = gtu.product(origShape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003090 factors = TosaArgGen.getFactors(totalElements)
3091
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003092 # Find new shapes up to the number of permutations asked for
3093 # This code is NOT fast. Fortunately, the numbers are fairly small.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003094 for p in range(testGen.args.num_rand_permutations):
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00003095 # Rank from 1 to TOSA_TENSOR_MAX_RANK
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003096 newRank = rng.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003097 if len(factors) < newRank:
3098 continue
3099
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003100 # escape_counter limits the generation of new shapes to a reasonable time
3101 for escape_counter in range(100):
3102
3103 # Generate the new shape of the chosen new rank
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003104 newShape = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003105 remainingElements = totalElements
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003106 shuffledFactors = rng.permutation(factors)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003107 for i in range(1, newRank):
3108 # pick rank-1 factors
3109 newShape.append(shuffledFactors[0])
3110 remainingElements = remainingElements // shuffledFactors[0]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003111 shuffledFactors = rng.permutation(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003112 TosaArgGen.getFactors(remainingElements)
3113 )
3114 newShape.append(remainingElements)
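# Illustrative sketch: for 24 total elements and newRank = 3, this loop
# might pick factor 2 (leaving 12), then factor 3 (leaving 4), giving a
# newShape of [2, 3, 4] that preserves the element count.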
3115
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003116 # Check for duplicates
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003117 duplicate = False
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003118 for name, args_dict in arg_list:
3119 if args_dict["new_shape"] == newShape:
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003120 duplicate = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003121 break
3122
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00003123 if not duplicate:
3124 outShape = "x".join([str(x) for x in newShape])
3125 arg_list.append(
3126 (
3127 "perm{}_rank{}_out{}".format(p, newRank, outShape),
3128 {"new_shape": newShape},
3129 )
3130 )
3131 # Found an output shape for this permutation
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003132 break
3133
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003134 # Now add data generator types
3135 arg_list = TosaArgGen._add_data_generators(
3136 testGen,
3137 opName,
evacha019c96eef2024-02-07 11:21:55 +00003138 shapeList,
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00003139 dtype,
3140 arg_list,
3141 error_name,
3142 )
3143
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003144 return arg_list
3145
3146 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003147 def agTranspose(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003148 arg_list = []
3149
3150 ifm_shape = shapeList[0]
3151
3152 if error_name == ErrorIf.IndexOutsideBounds:
3153 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
3154 incorrect_small_index = range(-len(ifm_shape), 0)
3155 permutations = [p for p in itertools.permutations(incorrect_large_index)]
3156 permutations.extend(
3157 [p for p in itertools.permutations(incorrect_small_index)]
3158 )
3159 elif error_name == ErrorIf.IndexUsedTwice:
3160 # Create list with a duplicated index
3161 perm_range = list(range(len(ifm_shape)))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003162 index_choice = rng.choice(range(len(perm_range)))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003163 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
3164 permutations = [p for p in itertools.permutations(perm_range)]
3165
3166 else:
3167 # Get all permutations
3168 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
3169
3170 # Limit to possible permutations from shape dimension or argument setting
3171 limit = min(len(permutations), testGen.args.num_rand_permutations)
3172
3173 # Get random permutation generator that uses all permutations
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003174 random_permutations = rng.permutation(permutations)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003175
3176 # Create list of required amount of permutations
3177 arg_list = [
evacha0198477222024-01-26 12:25:32 +00003178 ("perm{}".format(p), {"perms": random_permutations[p].tolist()})
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003179 for p in range(limit)
3180 ]
evacha0198477222024-01-26 12:25:32 +00003181 # Now add data generator types
3182 arg_list = TosaArgGen._add_data_generators(
3183 testGen,
3184 opName,
evacha019c96eef2024-02-07 11:21:55 +00003185 shapeList,
evacha0198477222024-01-26 12:25:32 +00003186 dtype,
3187 arg_list,
3188 error_name,
3189 )
3190 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003191 return arg_list
3192
3193 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003194 def agSlice(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003195 arg_list = []
3196
3197 ifm_shape = shapeList[0]
3198 rank = len(ifm_shape)
3199
3200 for p in range(testGen.args.num_rand_permutations):
3201 start = []
3202 size = []
3203
3204 valid = True
3205
3206 for i in range(rank):
3207 if ifm_shape[i] > 1:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003208 start.append(rng.randInt(0, ifm_shape[i]))
3209 size.append(rng.randInt(0, ifm_shape[i] - start[i]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003210
3211 # Invalid slice size?
3212 if size[i] == 0:
3213 valid = False
3214 else:
3215 start.append(0)
3216 size.append(1)
3217
3218 if valid:
3219 # If ERROR_IF test required then incorrect start, size will be returned
3220 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003221 rng, error_name, ifm_shape, start, size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003222 )
evacha017f7d4252024-01-24 12:08:09 +00003223 arg_list.append(("perm{}".format(p), {"start": start, "size": size}))
3224 # Now add data generator types
3225 arg_list = TosaArgGen._add_data_generators(
3226 testGen,
3227 opName,
evacha019c96eef2024-02-07 11:21:55 +00003228 shapeList,
evacha017f7d4252024-01-24 12:08:09 +00003229 dtype,
3230 arg_list,
3231 error_name,
3232 )
3233 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003234 return arg_list
3235
3236 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003237 def agTile(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003238 arg_list = []
3239
3240 ifm_shape = shapeList[0]
3241 rank = len(ifm_shape)
3242
3243 for p in range(testGen.args.num_rand_permutations):
3244
3245 # Pick a few random, but small multiple values
3246 # because otherwise this has a tendency to generate
3247 # enormous tensors
3248 multiples = []
3249 for i in range(rank):
3250 if ifm_shape[i] > 1000:
3251 # Multiple of 1 if ifm_shape dimension is large to reduce
3252 # tensor size
3253 multiples.append(1)
3254 elif max(ifm_shape) > 1000:
3255 multiples.append(2)
3256 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003257 multiples.append(rng.randInt(1, 4))
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003258 arg_list.append(("perm{}".format(p), {"multiples": multiples}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003259
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003260 # Now add data generator types
3261 arg_list = TosaArgGen._add_data_generators(
3262 testGen,
3263 opName,
evacha019c96eef2024-02-07 11:21:55 +00003264 shapeList,
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003265 dtype,
3266 arg_list,
3267 error_name,
3268 )
3269 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003270 return arg_list
3271
3272 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003273 def agResize(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003274 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003275 ifm_shape = shapeList[0]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003276
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003277 def get_aspect_ratio_resize_params():
3278 common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003279 aspect_ratio = rng.choice(common_aspect_ratios)
3280 invert = rng.choice((False, True))
3281 letterbox = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003282
3283 scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
3284 scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
3285 scale_y_d = scale_x_d = 1
3286 offset_x = offset_y = 0
3287
3288 if letterbox:
3289 max_border = scale_y_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003290 border_y = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003291 border_x = 0
3292 else:
3293 # Pillarboxing
3294 border_y = 0
3295 max_border = scale_x_n
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003296 border_x = rng.randInt(low=0, high=max_border)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003297
3298 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3299 offset = (offset_y, offset_x)
3300 border = (border_y, border_x)
3301
3302 return scale, offset, border
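# Illustrative sketch with hypothetical choices: aspect_ratio (16, 9) and
# invert = False give scale = (9, 1, 16, 1); with letterbox = True a border_y
# in [0, 9) is picked (say 4), so border = (4, 0) and offset stays (0, 0).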
3303
3304 def get_upscale_downscale_params():
3305 valid_params = False
3306 while not valid_params:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003307 upscale = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003308
3309 # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003310 origin_sampling = rng.choice((False, True))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003311
3312 if upscale:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003313 shift = rng.randInt(low=1, high=4)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003314 scale_x_d = scale_y_d = 1
3315 scale_x_n = scale_y_n = (
3316 1 << shift if origin_sampling else 2 << shift
3317 )
3318 border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
3319 offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
3320 else:
3321 scale_x_n = 1
3322 scale_y_n = 1
3323
3324 # Return list of valid scale_*_d values (max value 4) given input dim shape
3325 def get_valid_denom(ifm_dim):
3326 return [x for x in range(1, 5) if ifm_dim % x == 1]
3327
3328 # Generate list of valid downscale values and choose one randomly
3329 valid_scale_y_ds = get_valid_denom(ifm_shape[1])
3330 valid_scale_x_ds = get_valid_denom(ifm_shape[2])
3331
3332 if not valid_scale_y_ds and not valid_scale_x_ds:
3333 # Bad parameters, skip
3334 continue
3335
3336 if not valid_scale_y_ds:
3337 scale_y_d = 1
3338 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003339 scale_y_d = rng.choice(valid_scale_y_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003340
3341 if not valid_scale_x_ds:
3342 scale_x_d = 1
3343 else:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003344 scale_x_d = rng.choice(valid_scale_x_ds)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003345
3346 border_x = border_y = 0
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003347 offset_y = rng.randInt(0, 16 * scale_y_n)
3348 offset_x = rng.randInt(0, 16 * scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003349 valid_params = True
3350
3351 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3352 offset = (offset_y, offset_x)
3353 border = (border_y, border_x)
3354 return scale, offset, border
3355
3356 def get_rand_params():
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003357 def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
3358 scale = scale_n / scale_d
3359 if scale > max_scale:
3360 factor = scale / max_scale
3361 new_scale_d = math.ceil(scale_d * factor)
3362 assert scale_n / new_scale_d <= max_scale
3363 scale_d = new_scale_d
3364 return scale_d
3365
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003366 # Scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003367 scale_y_n = rng.randInt(low=1, high=(1 << 11))
3368 scale_x_n = rng.randInt(low=1, high=(1 << 11))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003369
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003370 scale_y_d = rng.randInt(low=1, high=(16 * scale_y_n))
3371 scale_x_d = rng.randInt(low=1, high=(16 * scale_x_n))
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003372
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003373 scale_y_d = fix_scale_to_max_scale(
3374 scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3375 )
3376 scale_x_d = fix_scale_to_max_scale(
3377 scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3378 )
3379
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003380 # Offsets and border within the scale
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003381 offset_y = rng.randInt(low=-scale_y_n, high=(16 * scale_y_n))
3382 offset_x = rng.randInt(low=-scale_x_n, high=(16 * scale_x_n))
3383 border_y = rng.randInt(low=(-16 * scale_y_n), high=scale_y_n)
3384 border_x = rng.randInt(low=(-16 * scale_x_n), high=scale_x_n)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003385
3386 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3387 offset = (offset_y, offset_x)
3388 border = (border_y, border_x)
3389 return scale, offset, border
3390
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003391 def get_level_8k_params():
3392 # Create 64x scale - 64/1 to 2048/32
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003393 scale_d = rng.randInt(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003394 low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
3395 )
3396 scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
3397 # Create half to fifth scaling
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003398 scale_d_alt = rng.randInt(low=2, high=6)
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003399 scale_n_alt = 1
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003400 switch = rng.choice((False, True))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003401 if switch:
3402 scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
3403 else:
3404 scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)
3405
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003406 offset_y = rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
3407 offset_x = rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003408 offset = (offset_y, offset_x)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003409 border_y = rng.choice((-16 * scale[0], 0, scale[0] - 1))
3410 border_x = rng.choice((-16 * scale[2], 0, scale[2] - 1))
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003411 border = (border_y, border_x)
3412 return scale, offset, border
3413
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003414 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003415 # Exclude illegal {mode, type} configurations. Pick legal output types
3416 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
3417 outputDTypeList = [DType.INT8]
3418 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
3419 outputDTypeList = [DType.INT16]
3420 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
3421 outputDTypeList = [DType.INT32]
3422 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
3423 outputDTypeList = [DType.INT48]
James Ward8b390432022-08-12 20:48:56 +01003424 elif dtype == DType.FP16:
3425 outputDTypeList = [DType.FP16]
James Ward24dbc422022-10-19 12:20:31 +01003426 elif dtype == DType.BF16:
3427 outputDTypeList = [DType.BF16]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003428 elif dtype == DType.FP32:
3429 outputDTypeList = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00003430 elif dtype == DType.FP8E4M3:
3431 outputDTypeList = [DType.FP8E4M3]
3432 elif dtype == DType.FP8E5M2:
3433 outputDTypeList = [DType.FP8E5M2]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003434 elif error_name == ErrorIf.WrongInputType:
3435 # If an incorrect input type is used then we set a 'correct'
3436 # output type to avoid other errors
3437 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
3438 else:
3439 continue
3440
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003441 arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"
3442
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003443 for outputDType in outputDTypeList:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003444 perm = 0
3445 while perm < testGen.args.num_rand_permutations:
3446 # Random choice of type of params we are testing
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003447 if not testGen.args.level8k:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003448 _rnd_param_fn = rng.choice(
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003449 (
3450 get_rand_params,
3451 get_upscale_downscale_params,
3452 get_aspect_ratio_resize_params,
3453 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003454 )
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003455 scale, offset, border = _rnd_param_fn()
3456 else:
3457 scale, offset, border = get_level_8k_params()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003458
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003459 # Expand params for bounds-checking
3460 (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
3461 (offset_y, offset_x) = offset
3462 (border_y, border_x) = border
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003463
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003464 # Make sure output dimensions OH and OW are integers
3465 partial_output_y = (
3466 (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
3467 )
3468 partial_output_x = (
3469 (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
3470 )
3471 if error_name == ErrorIf.ResizeOutputShapeNonInteger:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003472 # Look for non-integer test
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003473 if (
3474 partial_output_y % scale_y_d == 0
3475 and partial_output_x % scale_x_d == 0
3476 ):
3477 # Skip this test as it doesn't produce NonInteger output
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003478 if perm > 0:
3479 perm += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003480 continue
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003481 else:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003482 # Alter the scaling factors to make the output integer
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003483 while partial_output_y % scale_y_d != 0:
3484 scale_y_d -= 1
3485 while partial_output_x % scale_x_d != 0:
3486 scale_x_d -= 1
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003487 # Make sure we are still within max scaling
3488 if (
3489 scale_y_n / scale_y_d
3490 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
3491 scale_x_n / scale_x_d
3492 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
3493 # Skip the test as it is using too large a scaling factor
3494 if perm > 0:
3495 perm += 1
3496 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003497
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003498 output_y = partial_output_y // scale_y_d + 1
3499 output_x = partial_output_x // scale_x_d + 1
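# Worked sketch with hypothetical values: ifm_shape[1] = 16, scale_y = 4/2,
# offset_y = 0 and border_y = 0 give
#   partial_output_y = (16 - 1) * 4 - 0 + 0 = 60
#   output_y = 60 // 2 + 1 = 31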
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003500
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003501 if (
3502 output_y >= testGen.args.max_resize_output_dim
3503 or output_x >= testGen.args.max_resize_output_dim
3504 ) and error_name is None:
3505 # Skip positive test if output dim will be too high
3506 # Avoid high test latency and OOM issues
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003507 if not testGen.args.level8k or perm > 0:
3508 perm += 1
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003509 continue
3510
3511 if (
3512 output_y <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003513 or output_y >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003514 or output_x <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003515 or output_x >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003516 ):
3517 # Output dimensions out of scope
3518 if error_name is not None and perm > 0:
3519 # As long as we have one ERROR_IF test, don't worry
3520 # about creating all the other permutations
3521 perm += 1
3522 continue
3523
3524 if error_name == ErrorIf.ResizeOutputShapeMismatch and (
3525 (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003526 output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003527 and output_y - scale_y_d < 1
3528 )
3529 or (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003530 output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003531 and output_x - scale_x_d < 1
3532 )
3533 ):
3534 # Can't create a negative test with these params as it
3535 # will create invalid output size
3536 if perm > 0:
3537 perm += 1
3538 continue
3539
3540 scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
3541 offset = [offset_y, offset_x]
3542 border = [border_y, border_x]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003543
3544 # Common for all data types
3545 if error_name is not None:
3546 (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003547 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003548 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003549 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003550 outputDTypeNew,
3551 ) = TosaErrorIfArgGen.eiResizeErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003552 rng,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003553 error_name,
3554 mode,
3555 dtype,
3556 shapeList,
3557 outputDType,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003558 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003559 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003560 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003561 )
3562 else:
3563 outputDTypeNew = outputDType
3564
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003565 arg_to_append = (
3566 arg_str.format(
3567 "N" if mode == ResizeMode.NEAREST else "B",
3568 testGen.typeStr(outputDTypeNew),
3569 scale[0],
3570 scale[1],
3571 scale[2],
3572 scale[3],
3573 offset[0],
3574 offset[1],
3575 border[0],
3576 border[1],
3577 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003578 {
3579 "mode": mode,
3580 "scale": scale,
3581 "offset": offset,
3582 "border": border,
3583 "output_dtype": outputDTypeNew,
3584 },
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003585 )
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003586 if arg_to_append in arg_list:
3587 # Skip already generated test params
3588 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003589
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003590 # Valid permutation
3591 perm += 1
3592 arg_list.append(arg_to_append)
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003593
3594 # Now add data generator types
3595 arg_list = TosaArgGen._add_data_generators(
3596 testGen,
3597 opName,
evacha019c96eef2024-02-07 11:21:55 +00003598 shapeList,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003599 dtype,
3600 arg_list,
3601 error_name,
3602 )
3603 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003604 return arg_list
3605
3606 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003607 def agTable(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003608 arg_list = []
3609
3610 if dtype == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003611 table = np.int32(rng.integers(low=-128, high=128, size=[256])).tolist()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003612 else: # INT16
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003613 table = np.int32(rng.integers(low=-32768, high=32768, size=[513])).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003614 # Make sure all slopes are within the REQUIRE'd min/max range for 16-bit int
3615 for idx in range(len(table) - 1):
3616 slope = table[idx + 1] - table[idx]
3617 # Alter the next table entry to force the slope to be ok
3618 if slope > 32767:
3619 table[idx + 1] -= slope - 32767
3620 if slope < -32768:
3621 table[idx + 1] -= slope + 32768
3622 slope = table[idx + 1] - table[idx]
3623 assert slope <= 32767 and slope >= -32768
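# Illustrative sketch: if table[idx] were -30000 and table[idx + 1] were 5000,
# the slope of 35000 exceeds 32767, so table[idx + 1] is reduced by
# 35000 - 32767 = 2233 to 2767, making the slope exactly 32767.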
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003624 arg_list.append(
3625 (
3626 "",
Jeremy Johnson587cc842024-02-08 11:45:44 +00003627 {"table": table},
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003628 )
3629 )
Jeremy Johnson587cc842024-02-08 11:45:44 +00003630 # Now add data generator types
3631 arg_list = TosaArgGen._add_data_generators(
3632 testGen,
3633 opName,
evacha019c96eef2024-02-07 11:21:55 +00003634 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003635 dtype,
3636 arg_list,
3637 error_name,
3638 )
3639 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003640 return arg_list
3641
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003642 def agCondIf(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003643 # CondIf generates the condition values here.
3644 # Convert to tensors in the build function, along with the
3645 # then and else blocks
3646 arg_list = []
3647
3648 for c in [False, True]:
Jeremy Johnson587cc842024-02-08 11:45:44 +00003649 arg_list.append(("cond{}".format(int(c)), {"condition": c}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003650
Jeremy Johnson587cc842024-02-08 11:45:44 +00003651 # Now add data generator types
3652 arg_list = TosaArgGen._add_data_generators(
3653 testGen,
3654 opName,
evacha019c96eef2024-02-07 11:21:55 +00003655 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003656 dtype,
3657 arg_list,
3658 error_name,
3659 )
3660 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003661 return arg_list
3662
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01003663 def agWhileLoop(testGen, rng, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003664 # While loop: 0 iterations, 1, more than 1
3665 arg_list = []
3666
Jeremy Johnson587cc842024-02-08 11:45:44 +00003667 for iterations in [0, 1, 4]:
3668 arg_list.append(("iter{}".format(iterations), {"iterations": iterations}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003669
Jeremy Johnson587cc842024-02-08 11:45:44 +00003670 # Now add data generator types
3671 arg_list = TosaArgGen._add_data_generators(
3672 testGen,
3673 opName,
evacha019c96eef2024-02-07 11:21:55 +00003674 shapeList,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003675 dtype,
3676 arg_list,
3677 error_name,
3678 )
3679 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003680 return arg_list