#!/usr/bin/env python3

# Copyright (c) 2020-2021, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import argparse
import sys
import re
import os
import subprocess
import shlex
import json
import glob
import math
import queue
import threading
import traceback
import itertools

from enum import IntEnum, Enum, unique
from tosa_ref_run import TosaReturnCode

# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(
    os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
)
import tosa_serializer as ts
from tosa_serializer import *
import tosa
from tosa_error_if import ErrorIf

# Convenience variables to the flatc-generated types that should be enums, but aren't
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()


class TosaQuantGen:
    """QuantizedInfo random generator helper functions.  Specify with 'qgen': in the operator definition"""

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype):
        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        if dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.UnaryQuantInfo(
            TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
        )
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList):
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3
        input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
        weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.MatMulQuantInfo(
            TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
        )
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(scaleFp, scaleBits, m, multiplier, shift))

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift

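    # Worked example (illustrative, not from the original source):
    # computeMultiplierAndShift(0.5, True) gets m, e = math.frexp(0.5) == (0.5, 0),
    # so multiplier = round(0.5 * 2**31) = 2**30 and shift = -0 + 31 = 31;
    # the fixed-point rescale (value * multiplier) >> shift then
    # approximates value * 0.5.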

class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.  The actual random data is generated separately for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list
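
    # Sketch of a typical result (made-up values): for rank 3 with
    # pl + const == 2, a base shape of [3, 4, 5], bcast_idx == 1 and
    # fuzz_idx == 0, the returned list is [[3, 4, 5], [1, 4, 5]].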

    @staticmethod
    def tgConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]
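
    # Illustrative example (made-up values): an IFM of [1, 8, 8, 3] with
    # filter_hw == [3, 3] and filter_m == 2 gives an HWCM filter of
    # [3, 3, 3, 2] and a bias of [6], since the output depth is M * C = 2 * 3.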

    @staticmethod
    def tgFullyConnected(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]
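
    # Illustrative example (made-up values): a_shape [2, 3, 4] with b_oc == 5
    # gives b_shape [2, 4, 5], i.e. a batched [N, H, C] x [N, C, W] matmul.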

    @staticmethod
    def tgConcat(testGen, opName, rank):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis):
        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        shape = shapeList[0]
        if len(shapeList) == 2 or shape[axis] < len(shapeList):
            return shapeList

        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())
            shape[axis] = remaining_length
            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
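
    # Illustrative trace (made-up values): for shapeList == [[2, 8, 3]] * 4 and
    # axis == 1, the first shape is kept whole and the axis length is halved
    # repeatedly for the rest, giving [[2, 8, 3], [2, 4, 3], [2, 2, 3], [2, 2, 3]].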


class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for operators that take
    attributes or other parameters.  The return value is a list of (descriptive_name, [arglist])
    tuples where the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function."""

    def __init__(self):
        pass

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype):
        """Build the axis argument for operators that take a single axis"""
        axes = []

        shape = shapeList[0]

        for a in range(0, len(shape)):
            axes.append(("axis{}".format(a), [a]))
        return axes

    @staticmethod
    def agConv(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]
        # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
        k = [int(x) for x in opName.split("_")[-1].split("x")]

        # Check the rank
        rank = 5 if opName.startswith("conv3d") else 4
        assert len(ifm_shape) == rank
        assert len(filter_shape) == rank

        # kernel rank omits batch and channels
        k_rank = rank - 2

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
        s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * k_rank))}
        d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

        # add some oversize argument values
        if max(ifm_shape) < 64:
            bigPadding = 9
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
            bigDilation = 7
            dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})

        # There are too many parameter combinations, so generate them sparsely
        # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
        sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
        if sparsity < 13:
            sparsity = 1
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
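        # Worked example of the sparsity rule (made-up counts): 2000
        # combinations give sparsity = 2000 // 100 + 1 = 21, which the loop
        # bumps to 23 (not divisible by 2, 3 or 5), so roughly one combination
        # in 23 is kept; below 1200 combinations sparsity collapses to 1 and
        # everything is kept.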
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (n % sparsity == 0
                        # padding must not exceed the kernel size ?
                        # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
                        # the padded shape must exceed the kernel size
                        and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
                        # the padded shape must exceed the dilation
                        and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
                    ):
                        arg_list.append(
                            (
                                "st{}_pad{}_dilat{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in p]),
                                    "".join([str(x) for x in d]),
                                ),
                                [s, p, d],
                            )
                        )
                    n += 1

        return arg_list

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 2))}
        s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * 2))}

        # add some oversize argument values
        if max(ifm_shape) < 64:
            bigPadding = 9
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
            bigDilation = 7
            dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})

        # There are too many parameter combinations, so generate them sparsely
        # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
        sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
        if sparsity < 13:
            sparsity = 1
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if n % sparsity == 0:
                        # Determine the output shape
                        oh = (
                            ifm_shape[1]
                            - filter_shape[1]
                            - (filter_shape[1] - 1) * (d[0] - 1)
                            + 2 * p[0]
                        ) // s[0] + 1
                        ow = (
                            ifm_shape[2]
                            - filter_shape[2]
                            - (filter_shape[2] - 1) * (d[1] - 1)
                            + 2 * p[1]
                        ) // s[1] + 1
                        os = [ifm_shape[0], oh, ow, filter_shape[0]]
                        arg_list.append(
                            (
                                "st{}_pad{}_dilat{}_os{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in p]),
                                    "".join([str(x) for x in d]),
                                    "x".join([str(x) for x in os]),
                                ),
                                [s, p, d, os],
                            )
                        )
                    n += 1

        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype):
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        for paddings in shape_pad_values:
            name = "pad"
            for r in range(rank):
                before, after = paddings[r]
                name = f"{name}{before}{after}"
            arg_list.append((name, [np.array(paddings)]))

        return arg_list
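
    # Illustrative output (rank 2, pad_max == 1): names run "pad0000" through
    # "pad1111"; e.g. the entry ("pad0101", [np.array(((0, 1), (0, 1)))]) pads
    # one element after each of the two dimensions.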

    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype):
        arg_list = []

        shape = shapeList[0]
        assert len(shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 2)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        # add some oversize argument values
        bigStride = 7
        strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
        bigKernel = 6
        kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
        if max(shape) < 64:
            # padding must be less than the kernel size
            bigPadding = bigKernel - 1
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})

        # There are too many parameter combinations, so generate them sparsely
        sparsity = len(paddings) * len(strides) * len(kernels) // 500 + 1
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for k in sorted(list(kernels)):
                    if (n % sparsity == 0
                        # padding must not exceed the kernel size
                        and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # the padded shape must exceed the kernel size
                        and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
                    ):
                        arg_list.append(
                            (
                                "st{}_kern{}_pad{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in k]),
                                    "".join([str(x) for x in p]),
                                ),
                                [s, p, k],
                            )
                        )
                    n += 1

        return arg_list
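
    # Illustrative entry (made-up values): stride (1, 1), padding (0, 0, 0, 0)
    # and kernel (2, 2) yield ("st11_kern22_pad0000", [(1, 1), (0, 0, 0, 0), (2, 2)]).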

    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype):
        arg_list = []

        # Enumerate the output types here
        if inDtype == DType.INT8:
            dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT16:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT32:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FLOAT:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))

        return arg_list

    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype):
        arg_list = []

        # Enumerate the output types here
        for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
            if inDtype == DType.UINT8 and dtype != DType.INT8:
                # The only output dtype for UINT8 is INT8, skip all other combinations
                continue
            if inDtype != DType.INT8 and dtype == DType.UINT8:
                # The only input dtype for UINT8 is INT8, skip all other combinations
                continue

            for scale32 in [False, True]:
                for double_round in [False, True]:
                    for per_channel in [False, True]:

                        if inDtype == DType.INT48 and scale32:
                            # Illegal condition.  Must be scale32=False
                            continue
                        if double_round and not scale32:
                            # Illegal condition.  ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[dtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [dtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

    @staticmethod
    def agMul(testGen, opName, shapeList, dtype):
        arg_list = []

        if dtype is DType.INT32:
            for p in range(testGen.args.num_rand_permutations):

                shift = testGen.randInt(0, 32)

                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("perm0_shift0", [0]))

        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype):
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    # Helper function for reshape.  Gets the factors of a value up to its square root.
    @staticmethod
    def getFactors(val, start=1):
        factors = []

        for i in range(start, int(np.sqrt(val)) + 1):
            if (val % i) == 0:
                factors.append(i)

        return factors
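
    # Example (for illustration): getFactors(12) scans 1..int(sqrt(12)) == 3
    # and returns [1, 2, 3]; the larger cofactors (4, 6, 12) reappear in
    # agReshape below as the remaining element count.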

    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype):
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast.  Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)

                # Toss in a -1 sometimes
                minusOne = testGen.randInt(0, newRank * 4)
                if minusOne < newRank:
                    newShape[minusOne] = -1

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list
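
    # Illustrative output (made-up values): an input of shape [2, 3, 4]
    # (24 elements) can yield entries such as ("perm0_rank2", [[6, 4]]) or,
    # when the occasional -1 wildcard is tossed in, [[-1, 4]].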

    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]

        # Get all permutations
        permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Get random permutation generator that uses all permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create list of required amount of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list

    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            begin = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    begin.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    begin.append(0)
                    size.append(1)

            if valid:
                arg_list.append(("perm{}".format(p), [begin, size]))
        return arg_list

    @staticmethod
    def agTile(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random, but small multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                if ifm_shape[i] > 1000:
                    # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
                    multiples.append(1)
                elif max(ifm_shape) > 1000:
                    multiples.append(2)
                else:
                    multiples.append(testGen.randInt(1, 4))
            arg_list.append(("perm{}".format(p), [multiples]))

        return arg_list

    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]

        for m in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations.  Pick legal output types
            if m == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif m == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):

                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them.
                    # An output_dim of 1 would cause the offset to exceed the
                    # allowed range, so the minimum value produced below is 2.
                    output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
                    while ((float(ifm_shape[1]) / float(output_dims[0])) >= 16):
                        output_dims[0] += 1
                    while ((float(ifm_shape[2]) / float(output_dims[1])) >= 16):
                        output_dims[1] += 1

                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]

                        if error_name is not None:
                            shift, stride, stride_fp, offset, offset_fp = TosaErrorIfArgGen.eiResizeErrorIf(
                                testGen,
                                error_name,
                                shapeList,
                                outputDType,
                                shift,
                                stride,
                                stride_fp,
                                offset,
                                offset_fp
                            )

                        arg_list.append(
                            (
                                "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                    "N" if m == ResizeMode.NEAREST else "B",
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride_fp[0],
                                    stride_fp[1],
                                    offset_fp[0],
                                    offset_fp[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )
                    else:
                        shift = 11
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))

                        while (
                            stride_y >= (16 << shift)
                            or stride_x >= (16 << shift)
                            or offset_y >= (16 << shift)
                            or offset_x >= (16 << shift)
                            or offset_y <= (-16 << shift)
                            or offset_x <= (-16 << shift)
                        ):
                            shift = shift - 1
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]
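
                        # Worked example of the fixed-point encoding (made-up
                        # sizes): with ifm height 4 and output height 8,
                        # fp_stride_y == 0.5, so at shift == 11 (unit 2048)
                        # stride_y == round(0.5 * 2048) == 1024, well inside
                        # the 16 << shift limit checked above.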

                        if error_name is not None:
                            shift, stride, stride_fp, offset, offset_fp = TosaErrorIfArgGen.eiResizeErrorIf(
                                testGen,
                                error_name,
                                shapeList,
                                outputDType,
                                shift,
                                stride,
                                stride_fp,
                                offset,
                                offset_fp
                            )

                        arg_list.append(
                            (
                                "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                    "N" if m == ResizeMode.NEAREST else "B",
                                    shift,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride[0],
                                    stride[1],
                                    offset[0],
                                    offset[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )

        return arg_list

    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype):
        # While loop: 0 iterations, 1, more than 1
        arg_list = []

        for iter in [0, 1, 4]:
            arg_list.append(("iter{}".format(iter), [iter]))

        return arg_list

class TosaErrorIfArgGen:

    @staticmethod
    def eiResizeErrorIf(testGen, error_name, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):

        if outputDType == DType.FLOAT:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride_fp = testGen.rng.random(size=[2]) - 2
            elif error_name == ErrorIf.ShiftNotZero:
                shift = testGen.rng.integers(1, 5)
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
        else:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
            elif error_name == ErrorIf.ShiftSmallerOne:
                shift = testGen.rng.integers(-3, 1)
                if shift <= 0:
                    stride = [(16 >> -shift) - 1, (16 >> -shift) - 1]  # avoids other ERROR_IF checks
                    offset = [(16 >> -shift) - 1, (16 >> -shift) - 1]  # avoids other ERROR_IF checks
                else:
                    stride = [(16 << shift) - 1, (16 << shift) - 1]  # avoids other ERROR_IF checks
                    offset = [(16 << shift) - 1, (16 << shift) - 1]  # avoids other ERROR_IF checks
            elif error_name == ErrorIf.ShiftLargerEleven:
                shift = np.int16(testGen.rng.integers(12, 15))
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride[1] = shape[2] + testGen.rng.integers(1, 10)
            elif error_name == ErrorIf.StrideLargerEqualMax:
                stride = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetLargerEqualMax:
                offset = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetSmallerEqualMin:
                offset = [(-16 << shift) - 1, (-16 << shift) - 1]

        return shift, stride, stride_fp, offset, offset_fp


class TosaErrorValidator:
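
    # Usage note (inferred from the checks below, not documented here): each
    # evXxx validator supports two calling modes.  With check=False it only
    # reports param_reqs, the parameter space that can provoke the error;
    # with check=True plus the operator's kwargs it sets error_result when
    # the ERROR_IF condition actually fires.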

    @staticmethod
    def evMaxDimExceeded(check=False, **kwargs):
        error_name = ErrorIf.MaxDimExceeded
        param_reqs = {"rank": [4,4], "dtype": [DType.INT8], "shape": [[1, 16584, 5, 1]]}
        error_result = False
        error_reason = "At least one maximum dimension is larger than 16384"

        if check:
            input_shape = kwargs['input_shape'].shape
            output_shape = kwargs['output_shape']  # Note this is just (OH, OW)
            if ((input_shape[1] > 16384) or
                (input_shape[2] > 16384) or
                (output_shape[0] > 16384) or
                (output_shape[1] > 16384)):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

    @staticmethod
    def evStrideSmallerEqualZero(check=False, **kwargs):
        error_name = ErrorIf.StrideSmallerEqualZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Stride value smaller than or equal zero"

        if check:
            input_dtype = kwargs['input_dtype']
            if input_dtype == DType.FLOAT:
                stride = kwargs['stride_fp']
            else:
                stride = kwargs['stride']

            if min(stride) <= 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

    @staticmethod
    def evStrideLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.StrideLargerEqualMax
        param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
        error_result = False
        error_reason = "Stride value larger than or equal to maximum value"

        if check:
            shift = kwargs['shift']
            input_dtype = kwargs['input_dtype']
            stride = kwargs['stride']
            if input_dtype in [DType.INT8, DType.INT16]:
                if shift >= 0 and (stride[0] >= (16 << shift) or stride[1] >= (16 << shift)):
                    error_result = True
                elif shift < 0 and (stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

    @staticmethod
    def evStrideLargerDimension(check=False, **kwargs):
        error_name = ErrorIf.StrideLargerDimension
        param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
        error_result = False
        error_reason = "Stride value larger than or equal to H/W dimension"

        if check:
            shape = kwargs['input_shape'].shape
            input_dtype = kwargs['input_dtype']
            stride = kwargs['stride_fp']

            # Parenthesized so the float-dtype guard applies to both comparisons
            if input_dtype == DType.FLOAT and ((stride[0] > shape[1]) or (stride[1] > shape[2])):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

    @staticmethod
    def evOffsetSmallerEqualMin(check=False, **kwargs):
        error_name = ErrorIf.OffsetSmallerEqualMin
        param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
        error_result = False
        error_reason = "Offset value smaller than or equal to minimum value"

        if check:
            shift = kwargs['shift']
            input_dtype = kwargs['input_dtype']
            if input_dtype == DType.FLOAT:
                offset = kwargs['offset_fp']
            else:
                offset = kwargs['offset']

            if shift >= 0 and (offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)):
                error_result = True
            elif shift < 0 and (offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

    @staticmethod
    def evOffsetLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.OffsetLargerEqualMax
        param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
        error_result = False
        error_reason = "Offset value larger than or equal to maximum value"

        if check:
            shift = kwargs['shift']
            input_dtype = kwargs['input_dtype']
            if input_dtype == DType.FLOAT:
                offset = kwargs['offset_fp']
            else:
                offset = kwargs['offset']

            if shift >= 0 and (offset[0] >= (16 << shift) or offset[1] >= (16 << shift)):
                error_result = True
            elif shift < 0 and (offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

    @staticmethod
    def evShiftNotZero(check=False, **kwargs):
        error_name = ErrorIf.ShiftNotZero
        param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
        error_result = False
        error_reason = "Shift value must be zero for float input"

        if check:
            shift = kwargs['shift']
            input_dtype = kwargs['input_dtype']
            if input_dtype == DType.FLOAT and shift != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

    @staticmethod
    def evShiftSmallerOne(check=False, **kwargs):
        error_name = ErrorIf.ShiftSmallerOne
        param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
        error_result = False
        error_reason = "Shift value smaller than one"

        if check:
            shift = kwargs['shift']
            input_dtype = kwargs['input_dtype']
            if shift < 1 and input_dtype != DType.FLOAT:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict

    @staticmethod
    def evShiftLargerEleven(check=False, **kwargs):
        error_name = ErrorIf.ShiftLargerEleven
        param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
        error_result = False
        error_reason = "Shift value larger than eleven"

        if check:
            shift = kwargs['shift']
            if shift > 11:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs
        }
        return info_dict


class TosaInvalidValidator:

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        stride = args[1]
        stride_fp = args[4]
        output_dtype = args[8]

        if mode == ResizeMode.BILINEAR:
            # Invalid output data type / Invalid input datatype:
            # valid only if the (input, output) pair is a legal combination
            return not (
                (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                or (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                or (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
            )
        elif mode == ResizeMode.NEAREST:
            # Invalid output data type / Invalid input datatype
            return (
                (input_dtype != output_dtype) or
                (input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT])
            )
        else:
            # Invalid resize mode
            return True

    @staticmethod
    def ivBadStride(**kwargs):
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        stride_x = args[1][0]
        stride_y = args[1][1]
        stride_fp_x = args[4][0]
        stride_fp_y = args[4][1]

        if input_dtype == DType.FLOAT:
            if stride_fp_x <= 0 or stride_fp_y <= 0:
                # Negative or zero stride
                return True
        else:
            if stride_x <= 0 or stride_y <= 0:
                # Negative or zero stride
                return True
        return False

    @staticmethod
    def ivHeightWidthSmallerZero(**kwargs):
        opName = kwargs['opName']

        inputShapes = kwargs['shapeList']
        input = inputShapes[0]
        if not opName.endswith("pool2d"):
            filter = inputShapes[1]

        args = kwargs['args']
        strides = args[0]
        padding = args[1]
        dilations = args[2]
        if opName.endswith("pool2d"):
            kernel = args[2]

        if opName.startswith('conv2d'):
            h = (
                input[1]
                - filter[1]
                - (filter[1] - 1) * (dilations[0] - 1)
                + padding[0]
                + padding[1]
            ) // strides[0] + 1

            w = (
                input[2]
                - filter[2]
                - (filter[2] - 1) * (dilations[1] - 1)
                + padding[2]
                + padding[3]
            ) // strides[1] + 1
        elif opName.startswith("depthwise_conv2d"):
            h = (
                input[1]
                - filter[0]
                - (filter[0] - 1) * (dilations[0] - 1)
                + padding[0]
                + padding[1]
            ) // strides[0] + 1

            w = (
                input[2]
                - filter[1]
                - (filter[1] - 1) * (dilations[1] - 1)
                + padding[2]
                + padding[3]
            ) // strides[1] + 1
        elif opName.endswith("pool2d"):
            h = (input[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
            w = (input[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
        else:
            assert False, "Unrecognized Op"

        if h <= 0 or w <= 0:
            # Invalid parameter combination
            return True
        return False

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        args = kwargs['args']
        output_shape = args[3]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Negative output shape
            return True
        return False


class TosaTestGen:
    # Maximum rank of tensor supported by test generator.
    TOSA_TENSOR_MAX_RANK = 6

    def __init__(self, args):
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        self.ser = None
        self.rng = np.random.default_rng(self.random_seed)
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None

    def createSerializer(self, opName, testPath):
        self.testPath = os.path.join(opName, testPath)

        fullPath = os.path.join(self.basePath, self.testPath)
        os.makedirs(fullPath, exist_ok=True)
        self.ser = ts.TosaSerializer(fullPath)

    def getSerializer(self):
        return self.ser

    def serialize(self, testName):
        with open(
            os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
        ) as fd:
            fd.write(self.ser.serialize())

        with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
            fd.write(self.ser.writeJson("{}.tosa".format(testName)))

    def resetRNG(self, seed=None):
        if seed is None:
            seed = self.random_seed + 1
        self.rng = np.random.default_rng(seed)

    def getRandTensor(self, shape, dtype):
        if dtype == DType.BOOL:
            return np.bool_(self.rng.choice(a=[False, True], size=shape))
        # TOSA specific INT4 weight range from -7 to 7
        elif dtype == DType.INT4:
            return np.int32(self.rng.integers(low=-7, high=8, size=shape))
        elif dtype == DType.INT8:
            return np.int32(self.rng.integers(low=-128, high=128, size=shape))
        elif dtype == DType.UINT8:
            return np.int32(self.rng.integers(low=0, high=256, size=shape))
        elif dtype == DType.INT16:
            return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
        elif dtype == DType.INT32:
            return np.int32(
                self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
            )
        elif dtype == DType.INT48:
            return np.int64(
                self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
            )
        elif dtype == DType.FLOAT:
            return np.float32(self.rng.random(size=shape))
        else:
            raise Exception("Unrecognized Dtype: {}".format(dtype))

Kevin Cheng989cb052021-04-28 16:29:44 -07001536 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07001537 placeholders = []
1538
Kevin Cheng989cb052021-04-28 16:29:44 -07001539 assert len(shape_list) == len(dtype_list)
1540
1541 for idx, shape in enumerate(shape_list):
1542 arr = self.getRandTensor(shape, dtype_list[idx])
1543 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001544
1545 return placeholders
1546
Kevin Cheng989cb052021-04-28 16:29:44 -07001547 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07001548 consts = []
1549
Kevin Cheng989cb052021-04-28 16:29:44 -07001550 assert len(shape_list) == len(dtype_list)
1551
1552 for idx, shape in enumerate(shape_list):
1553 arr = self.getRandTensor(shape, dtype_list[idx])
1554 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001555
1556 return consts
1557
1558 def makeShape(self, rank):
1559 if self.targetted_shape:
1560 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001561 return np.int32(
1562 self.rng.integers(
1563 low=self.args.tensor_shape_range[0],
1564 high=self.args.tensor_shape_range[1],
1565 size=rank,
1566 )
1567 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001568
1569 def setTargetShape(self, shape):
1570 self.targetted_shape = shape
1571
1572 def randInt(self, low=0, high=256):
1573 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
1574
1575 def getRandNumberDType(self, dtype):
1576 if dtype == DType.FLOAT:
1577 return self.rng.random()
1578 elif dtype == DType.BOOL:
1579 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -07001580 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07001581 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07001582 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -07001583 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01001584 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -07001585 elif dtype == DType.INT16:
1586 low, high = (-32768, 32768)
1587 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001588 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -07001589 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001590 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -07001591 # Special size
1592 return np.int64(self.rng.integers(low, high, size=1))[0]
1593 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001594 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07001595
1596 return np.int32(self.rng.integers(low, high, size=1))[0]
1597
1598 def shapeStr(self, shape):
1599
1600 sStr = []
1601 # Convert to strings
1602 for i in shape:
1603 sStr.append(str(i))
1604
Kevin Cheng550ccc52021-03-03 11:21:43 -08001605 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001606
1607 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -07001608 if isinstance(t, list):
1609 assert len(t) >= 2
1610 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001611 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07001612 if t == DType.BOOL:
1613 return "b"
1614 elif t == DType.INT4:
1615 return "i4"
1616 elif t == DType.INT8:
1617 return "i8"
1618 elif t == DType.UINT8:
1619 return "u8"
1620 elif t == DType.INT16:
1621 return "i16"
1622 elif t == DType.INT32:
1623 return "i32"
1624 elif t == DType.INT48:
1625 return "i48"
1626 elif t == DType.FLOAT:
1627 return "float"
1628 else:
1629 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001630
1631 def typeWidth(self, t):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001632 """ Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -08001633 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07001634 return 4
1635 elif t == DType.INT8:
1636 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08001637 elif t == DType.UINT8:
1638 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07001639 elif t == DType.INT16:
1640 return 16
1641 elif t == DType.INT32:
1642 return 32
1643 elif t == DType.INT48:
1644 return 48
1645 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001646 raise Exception("Unknown dtype, cannot determine width: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001647
1648 # Argument generators
1649 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
1650 # where the string descriptor is used to generate the test name and
1651 # the build_fcn_arg_list is expanded and passed to the operator test
1652 # build function.
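# Illustrative example (hypothetical values, not from the original source):
# an argument generator such as TosaArgGen.agPooling might yield entries like
#   ("st11_kern22_pad0000", [[1, 1], [0, 0, 0, 0], [2, 2]])
# where the string becomes part of the test name and the list is splatted
# into the build function, e.g. build_pool2d(self, op, input, stride, pad,
# kernel).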
1653
Kevin Cheng550ccc52021-03-03 11:21:43 -08001654 def build_unary(self, op, a, qinfo=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001655 result_tens = OutputShaper.unaryOp(self.ser, a)
1656 self.ser.addOperator(op, [a.name], [result_tens.name], None, qinfo)
1657 return result_tens
1658
1659 def build_binary_broadcast(self, op, a, b):
1660 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1661 self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
1662 return result_tens
1663
1664 def build_binary_nonbroadcast(self, op, a, b):
1665 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
1666 self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
1667 return result_tens
1668
Kevin Chengaee1fac2020-11-11 13:54:06 -08001669 def build_arithmetic_right_shift(self, op, a, b, round):
1670 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1671
1672 attr = ts.TosaSerializerAttribute()
1673 attr.ArithmeticRightShiftAttribute(round)
1674
1675 self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
1676 return result_tens
1677
1678 def build_mul(self, op, a, b, shift):
Eric Kunzee5e26762020-10-13 16:11:07 -07001679 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1680
1681 # Special for multiply:
1682 # Force the result to INT32 for INT types
1683 if a.dtype != DType.FLOAT:
1684 result_tens.setDtype(DType.INT32)
1685
Kevin Chengaee1fac2020-11-11 13:54:06 -08001686 attr = ts.TosaSerializerAttribute()
1687 attr.MulAttribute(shift)
1688
1689 self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001690 return result_tens
1691
1692 def build_table(self, op, a):
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01001693 # Constant size depending on type, random values
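# Per the TOSA spec, 16-bit TABLE lookups interpolate between entries, so
# the table carries 513 values (512 intervals plus an end point), while the
# 8-bit form is a direct 256-entry lookup.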
1694 if a.dtype == DType.INT16:
Kevin Chengacb550f2021-06-29 15:32:19 -07001695 table_dtype = DType.INT16
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01001696 table_arr = self.getRandTensor([513], table_dtype)
1697 else:
1698 assert a.dtype == DType.INT8
1699 table_dtype = DType.INT8
1700 table_arr = self.getRandTensor([256], table_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001701
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01001702 table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
1703 result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001704 self.ser.addOperator(op, [a.name, table_tens.name], [result_tens.name], None)
1705
1706 return result_tens
1707
1708 def build_select(self, op, cond, a, b):
Eric Kunzee5e26762020-10-13 16:11:07 -07001709 result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
1710 self.ser.addOperator(op, [cond.name, a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001711 return result_tens
1712
1713 def build_comparison(self, op, a, b):
1714 result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
1715 self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
1716 return result_tens
1717
1718 def build_argmax(self, op, a, axis):
1719 result_tens = OutputShaper.argmaxOp(self.ser, a, axis)
1720
1721 attr = ts.TosaSerializerAttribute()
1722 attr.AxisAttribute(axis)
1723
1724 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1725 return result_tens
1726
Matthew Haddonb724efc2021-08-25 16:40:29 +01001727 def build_pool2d(self, op, input, stride, pad, kernel, qinfo=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001728 result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)
1729
1730 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07001731 attr.PoolAttribute(kernel, stride, pad)
Eric Kunzee5e26762020-10-13 16:11:07 -07001732
1733 self.ser.addOperator(op, [input.name], [result_tens.name], attr, qinfo)
1734 return result_tens
1735
1736 def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001737 assert len(padding) == 4
1738 result_tens = OutputShaper.conv2dOp(
1739 self.ser, ifm, filter, strides, padding, dilations
1740 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001741
1742 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07001743 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07001744
Kevin Cheng550ccc52021-03-03 11:21:43 -08001745 self.ser.addOperator(
1746 op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
1747 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001748 return result_tens
1749
Kevin Cheng1533b852021-09-01 12:51:58 -07001750 def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
1751 assert len(padding) == 6
1752 result_tens = OutputShaper.conv3dOp(
1753 self.ser, ifm, filter, strides, padding, dilations
1754 )
1755
1756 attr = ts.TosaSerializerAttribute()
1757 attr.ConvAttribute(padding, strides, dilations)
1758
1759 self.ser.addOperator(
1760 op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
1761 )
1762 return result_tens
1763
Kevin Cheng550ccc52021-03-03 11:21:43 -08001764 def build_transpose_conv2d(
Kevin Cheng989cb052021-04-28 16:29:44 -07001765 self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08001766 ):
1767 assert len(outpad) == 2
Eric Kunzee5e26762020-10-13 16:11:07 -07001768 result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
1769
1770 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07001771 attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07001772
Kevin Cheng550ccc52021-03-03 11:21:43 -08001773 self.ser.addOperator(
Kevin Cheng989cb052021-04-28 16:29:44 -07001774 op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08001775 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001776 return result_tens
1777
Kevin Cheng550ccc52021-03-03 11:21:43 -08001778 def build_depthwise_conv2d(
1779 self, op, ifm, filter, bias, strides, padding, dilations, qinfo
1780 ):
1781 result_tens = OutputShaper.depthwiseConv2dOp(
1782 self.ser, ifm, filter, strides, padding, dilations
1783 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001784
1785 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07001786 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07001787
Kevin Cheng550ccc52021-03-03 11:21:43 -08001788 self.ser.addOperator(
1789 op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
1790 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001791 return result_tens
1792
1793 def build_fully_connected(self, op, ifm, filter, bias, qinfo):
1794 result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)
1795
Kevin Cheng550ccc52021-03-03 11:21:43 -08001796 self.ser.addOperator(
1797 op, [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
1798 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001799 return result_tens
1800
1801 def build_matmul(self, op, a, b, qinfo):
1802 result_tens = OutputShaper.matmulOp(self.ser, a, b)
1803 self.ser.addOperator(op, [a.name, b.name], [result_tens.name], None, qinfo)
1804 return result_tens
1805
1806 def build_reduce(self, op, a, axis):
1807 result_tens = OutputShaper.reduceOp(self.ser, a, axis)
1808
1809 attr = ts.TosaSerializerAttribute()
1810 attr.AxisAttribute(axis)
1811
1812 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1813 return result_tens
1814
1815 def build_clamp(self, op, a):
1816 result_tens = OutputShaper.unaryOp(self.ser, a)
1817
1818 attr = ts.TosaSerializerAttribute()
Jeremy Johnson18e26662021-07-22 16:15:29 +01001819 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07001820
1821 if a.dtype == DType.FLOAT:
1822 attr.ClampAttribute(0, 0, min(v), max(v))
1823 else:
1824 attr.ClampAttribute(min(v), max(v), 0, 0)
1825
1826 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1827 return result_tens
1828
1829 def build_leaky_relu(self, op, a):
1830 result_tens = OutputShaper.unaryOp(self.ser, a)
1831 attr = ts.TosaSerializerAttribute()
1832
1833 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
1834
1835 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1836 return result_tens
1837
1838 # Needs an additional type/input
1839 def build_prelu(self, op, a):
1840 result_tens = OutputShaper.unaryOp(self.ser, a)
1841
1842 self.ser.addOperator(op, [a.name], [result_tens.name])
1843 return result_tens
1844
Eric Kunzee5e26762020-10-13 16:11:07 -07001845 def build_sigmoid(self, op, a):
1846 result_tens = OutputShaper.unaryOp(self.ser, a)
1847 self.ser.addOperator(op, [a.name], [result_tens.name])
1848 return result_tens
1849
1850 def build_tanh(self, op, a):
1851 result_tens = OutputShaper.unaryOp(self.ser, a)
1852 self.ser.addOperator(op, [a.name], [result_tens.name])
1853 return result_tens
1854
Matthew Haddon818ab902021-07-27 09:12:49 +01001855 def build_concat(self, op, *a):
Kevin Cheng93a16282021-08-31 16:14:03 -07001856 assert isinstance(a[-1], int)
Matthew Haddon818ab902021-07-27 09:12:49 +01001857
1858 # The axis is stored as the last element of the variable-length input tensor list
1859 axis = a[-1]
1860 a = a[:-1]
1861
1862 result_tens = OutputShaper.concatOp(self.ser, axis, *a)
Eric Kunzee5e26762020-10-13 16:11:07 -07001863
1864 attr = ts.TosaSerializerAttribute()
1865 attr.AxisAttribute(axis)
1866
Matthew Haddon818ab902021-07-27 09:12:49 +01001867 input_tensor_names = []
1868 for tensor in a:
1869 input_tensor_names.append(tensor.name)
1870
1871 self.ser.addOperator(op, input_tensor_names, [result_tens.name], attr)
return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001872
1873 def build_pad(self, op, a, padding, qinfo):
1874 result_tens = OutputShaper.padOp(self.ser, a, padding)
1875
1876 # Need to turn the padding array into a TOSA tensor here.
1877 # This is one of the few tensor operands that does not get
1878 # randomly generated
Kevin Cheng550ccc52021-03-03 11:21:43 -08001879 padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)
Eric Kunzee5e26762020-10-13 16:11:07 -07001880
Kevin Cheng550ccc52021-03-03 11:21:43 -08001881 self.ser.addOperator(
1882 op, [a.name, padding_tens.name], [result_tens.name], None, qinfo
1883 )
Matthew Haddone86fd342021-09-07 16:12:21 +01001884 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001885
1886 def build_reshape(self, op, a, newShape):
1887 result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)
1888
1889 attr = ts.TosaSerializerAttribute()
1890 attr.ReshapeAttribute(newShape)
1891
1892 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1893 return result_tens
1894
1895 def build_reverse(self, op, a, axis):
1896 result_tens = OutputShaper.unaryOp(self.ser, a)
1897
1898 attr = ts.TosaSerializerAttribute()
1899 attr.AxisAttribute(axis)
1900
1901 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1902 return result_tens
1903
1904 def build_transpose(self, op, a, perms):
1905 result_tens = OutputShaper.transposeOp(self.ser, a, perms)
1906
Kevin Cheng550ccc52021-03-03 11:21:43 -08001907 perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))
Eric Kunzee5e26762020-10-13 16:11:07 -07001908
1909 self.ser.addOperator(op, [a.name, perms_tens.name], [result_tens.name])
1910 return result_tens
1911
1912 def build_slice(self, op, a, begin, size):
1913 result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)
1914
1915 attr = ts.TosaSerializerAttribute()
1916 attr.SliceAttribute(begin, size)
1917
1918 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1919 return result_tens
1920
1921 def build_tile(self, op, a, multiples):
1922 result_tens = OutputShaper.tileOp(self.ser, a, multiples)
1923
1924 attr = ts.TosaSerializerAttribute()
1925 attr.TileAttribute(multiples)
1926
1927 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1928 return result_tens
1929
Kevin Cheng77d0f762020-11-24 10:26:32 -08001930 def build_gather(self, op, values):
Eric Kunzee5e26762020-10-13 16:11:07 -07001931
1932 # Create a new indices tensor
1933 # here with data that doesn't exceed the dimensions of the values tensor
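# Shape convention for GATHER (per the TOSA spec): values are (N, K, C),
# indices are (N, W) with each entry in [0, K), and the output is (N, W, C).
# W below is chosen randomly since it is otherwise unconstrained.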
1934
Kevin Cheng550ccc52021-03-03 11:21:43 -08001935 K = values.shape[1] # K
1936 W = self.randInt(
1937 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
1938 ) # W
1939 indicies_arr = np.int32(
1940 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
1941 ) # (N, W)
1942 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001943
Kevin Cheng77d0f762020-11-24 10:26:32 -08001944 result_tens = OutputShaper.gatherOp(self.ser, values, indicies)
Eric Kunzee5e26762020-10-13 16:11:07 -07001945
Kevin Cheng77d0f762020-11-24 10:26:32 -08001946 self.ser.addOperator(op, [values.name, indicies.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001947
1948 return result_tens
1949
Kevin Cheng77d0f762020-11-24 10:26:32 -08001950 def build_scatter(self, op, values_in, input):
1951
1952 # Create a new indices tensor
1953 # here with data that doesn't exceed the dimensions of the values_in tensor
1954
Kevin Cheng550ccc52021-03-03 11:21:43 -08001955 K = values_in.shape[1] # K
1956 W = input.shape[1] # W
1957 indicies_arr = np.int32(
1958 self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
1959 ) # (N, W)
1960 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001961
1962 result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)
1963
Kevin Cheng550ccc52021-03-03 11:21:43 -08001964 self.ser.addOperator(
1965 op, [values_in.name, indicies.name, input.name], [result_tens.name]
1966 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001967
1968 return result_tens
1969
Kevin Cheng550ccc52021-03-03 11:21:43 -08001970 def build_resize(
1971 self,
1972 op,
1973 input,
1974 mode,
1975 stride,
1976 offset,
1977 shift,
1978 stride_fp,
1979 offset_fp,
1980 output_dims,
1981 input_dtype,
1982 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01001983 validator_fcns,
1984 error_name = None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001985 ):
1986 result_tens = OutputShaper.resizeOp(
1987 self.ser,
1988 input,
1989 mode,
1990 stride,
1991 offset,
1992 shift,
1993 stride_fp,
1994 offset_fp,
1995 output_dims,
1996 input_dtype,
1997 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01001998 error_name
Kevin Cheng550ccc52021-03-03 11:21:43 -08001999 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002000
Matthew Haddone86fd342021-09-07 16:12:21 +01002001 # Check ERROR_IF statements
2002 for val_fcn in validator_fcns:
2003 val_result = val_fcn(
2004 check=True,
2005 shift=shift,
2006 input_dtype=input_dtype,
2007 input_shape=input,
2008 output_shape=output_dims,
2009 offset=offset,
2010 offset_fp=offset_fp,
2011 stride=stride,
2012 stride_fp=stride_fp)
2013
2014 validator_name = val_result['error_name']
2015 error_result = val_result['error_result']
2016 error_reason = val_result['error_reason']
2017
2018 if error_result:
2019 if error_name == validator_name:
2020 self.ser.setExpectedReturnCode(2, error_reason)
2021 else:
2022 print(f"Multiple ERROR_IF checks hit \nError required: {error_name}, Error_produced: {validator_name}")
2023 return None # Return None to delete test if wrong ERROR_IF is hit
2024
2025
Eric Kunzee5e26762020-10-13 16:11:07 -07002026 attr = ts.TosaSerializerAttribute()
Kevin Cheng77d0f762020-11-24 10:26:32 -08002027
Kevin Cheng550ccc52021-03-03 11:21:43 -08002028 attr.ResizeAttribute(
2029 output_dims, stride, offset, shift, stride_fp, offset_fp, mode
2030 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002031
2032 self.ser.addOperator(op, [input.name], [result_tens.name], attr)
2033 return result_tens
2034
2035 def build_identityn(self, op, val, val2):
2036
Kevin Cheng550ccc52021-03-03 11:21:43 -08002037 result_tens = OutputShaper.unaryOp(self.ser, val)
Eric Kunzee5e26762020-10-13 16:11:07 -07002038 result_tens2 = OutputShaper.unaryOp(self.ser, val2)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002039 self.ser.addOperator(
2040 op, [val.name, val2.name], [result_tens.name, result_tens2.name]
2041 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002042 return result_tens
2043
Kevin Cheng17e92022021-10-01 14:33:33 -07002044 def build_const(self, op, val):
2045 self.ser.addOutputTensor(val)
2046 return val
Eric Kunzee5e26762020-10-13 16:11:07 -07002047
2048 # Type Conversion
2049 def build_cast(self, op, val, out_dtype):
2050 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
2051 self.ser.addOperator(op, [val.name], [result_tens.name])
2052 return result_tens
2053
2054 def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
2055 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
2056
2057 if per_channel:
2058 nc = val.shape[-1]
2059 else:
2060 nc = 1
2061
2062 in_type_width = self.typeWidth(val.dtype)
2063 out_type_width = self.typeWidth(out_dtype)
2064
Kevin Cheng3a478572021-01-22 17:21:02 -08002065 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002066 input_zp = self.randInt(-128, 128)
2067 in_type_width = in_type_width + 1
Kevin Chengacb550f2021-06-29 15:32:19 -07002068 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002069 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07002070 in_type_width = in_type_width + 1
2071 else:
2072 input_zp = 0
2073
Kevin Cheng3a478572021-01-22 17:21:02 -08002074 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002075 output_zp = self.randInt(-128, 128)
2076 out_type_width = out_type_width + 1
2077 elif out_dtype == DType.UINT8:
2078 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07002079 out_type_width = out_type_width + 1
2080 else:
2081 output_zp = 0
2082
2083 # Calculate scale based on:
2084 # scale = a * (2^output_width) / (2^input_width)
2085
2086 a = np.float32(self.rng.random(size=[nc]))
2087 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
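# Worked example (illustrative, not from the original source): INT8 -> INT32
# gives in_type_width = 9 (8 bits plus the zero-point adjustment above) and
# out_type_width = 32, so with a = 0.5 the scale is
# 0.5 * 2^32 / 2^9 = 2^22 = 4194304 before clipping.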
2088
2089 if scale32:
Matthew Haddonb724efc2021-08-25 16:40:29 +01002091 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07002092 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
2093 else:
2094 # Cap the scaling at 2^15 - 1 for scale16
2095 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
2096
Kevin Cheng550ccc52021-03-03 11:21:43 -08002097 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002098
2099 multiplier_arr = np.int32(np.zeros(shape=[nc]))
2100 shift_arr = np.int32(np.zeros(shape=[nc]))
2101
2102 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002103 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
2104 scale_arr[i], scale32
2105 )
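# Illustrative decomposition (assuming computeMultiplierAndShift returns a
# normalized mantissa in [2^30, 2^31) for scale32): a scale of 2^22 becomes
# multiplier = 2^30, shift = 8, since 2^30 * 2^-8 = 2^22.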
Eric Kunzee5e26762020-10-13 16:11:07 -07002106
Kevin Cheng550ccc52021-03-03 11:21:43 -08002107 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07002108
2109 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002110 attr.RescaleAttribute(
2111 input_zp,
2112 output_zp,
2113 multiplier_arr,
2114 shift_arr,
2115 scale32,
2116 double_round,
2117 per_channel,
2118 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002119
2120 self.ser.addOperator(op, [val.name], [result_tens.name], attr)
2121 return result_tens
2122
2123 def build_cond_if_const(self, op, then_tens, else_tens, cond):
2124 # For cond_if with constants, we're supplied with then/else tensors that we ignore
2125 # (except for the generated shape) and the condition. Build Then/Else blocks
2126 # and fill them with const nodes for the body.
2127
2128 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08002129 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07002130
2131 # Make then/else tensors
2132 out_shape = then_tens.shape
Jeremy Johnson18e26662021-07-22 16:15:29 +01002133 then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
2134 else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002135
2136 # And the result tensor based on any of the outputs
Kevin Cheng550ccc52021-03-03 11:21:43 -08002137 result_tens = self.ser.addOutput(out_shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07002138
2139 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08002140 then_block = "THEN_BLOCK"
2141 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002142 attr = ts.TosaSerializerAttribute()
2143 attr.CondIfAttribute(then_block, else_block)
2144
2145 # Finally, build the op and the two blocks
2146 self.ser.addOperator(op, [cond_tens.name], [result_tens.name], attr)
2147
2148 self.ser.startBasicBlock(then_block)
2149 # Build the actual then/else tensors inside their blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08002150 then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002151 self.ser.addOutputTensor(then_tens)
2152
2153 self.ser.startBasicBlock(else_block)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002154 else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002155 self.ser.addOutputTensor(else_tens)
2156
2157 return result_tens
2158
2159 def build_cond_if_binary(self, op, a, b, cond):
2160 # For cond_if with a binary op in the then/else blocks, take a and b and
2161 # alternately add or subtract them based on the condition
2162
2163 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08002164 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07002165
Kevin Cheng550ccc52021-03-03 11:21:43 -08002166 result_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002167
2168 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08002169 then_block = "THEN_BLOCK"
2170 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002171 attr = ts.TosaSerializerAttribute()
2172 attr.CondIfAttribute(then_block, else_block)
2173
2174 # Finally, build the op and the two blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08002175 self.ser.addOperator(
2176 op, [cond_tens.name, a.name, b.name], [result_tens.name], attr
2177 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002178
2179 self.ser.startBasicBlock(then_block)
2180 self.ser.addInputTensor(a)
2181 self.ser.addInputTensor(b)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002182 then_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002183 self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])
2184
2185 self.ser.startBasicBlock(else_block)
2186 self.ser.addInputTensor(a)
2187 self.ser.addInputTensor(b)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002188 else_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002189 self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])
2190
2191 return result_tens
2192
2193 def build_while_loop(self, op, a, iter_val):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002194 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07002195
Kevin Cheng550ccc52021-03-03 11:21:43 -08002196 cond_block = "COND_BLOCK"
2197 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002198
2199 attr = ts.TosaSerializerAttribute()
2200 attr.WhileLoopAttribute(cond_block, body_block)
2201
2202 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08002203 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002204 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08002205 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07002206
2207 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002208 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2209 a_out = self.ser.addIntermediate(a.shape, a.dtype)
2210 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002211
2212 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08002213 self.ser.addOperator(
2214 op,
2215 [iter.name, a.name, acc.name],
2216 [iter_out.name, a_out.name, acc_out.name],
2217 attr,
2218 )
Kevin Chengb227ae52021-09-02 13:43:17 -07002219 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07002220
2221 # COND block (input: iter, output: cond_tens )
2222 self.ser.startBasicBlock(cond_block)
2223 self.ser.addInputTensor(iter)
2224 self.ser.addInputTensor(a)
2225 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002226 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
2227 cond_tens = self.ser.addOutput([], DType.BOOL)
2228 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002229
2230 # BODY block (input: a, acc, iter, output: a, acc, iter)
2231 # Note that local intermediate tensors need to be declared here for the outputs
2232 self.ser.startBasicBlock(body_block)
2233 self.ser.addInputTensor(iter)
2234 self.ser.addInputTensor(a)
2235 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002236 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
2237 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2238 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002239 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2240 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2241 self.ser.addOutputTensor(iter_body_out)
2242 self.ser.addOutputTensor(a)
2243 self.ser.addOutputTensor(acc_body_out)
2244
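# Net effect (illustrative trace): each pass through BODY computes
# acc += a and iter -= 1, and COND keeps looping while iter > 0, so the
# returned accumulator is acc = a * iter_val.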
2245 return acc_out
2246
Kevin Cheng550ccc52021-03-03 11:21:43 -08002247 def genOpTestList(
Matthew Haddon74567092021-07-16 15:38:20 +01002248 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
Kevin Cheng550ccc52021-03-03 11:21:43 -08002249 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002250
2251 try:
2252 op = self.TOSA_OP_LIST[opName]
2253 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002254 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002255
2256 # Initialize a new random number generator
2257 self.rng = np.random.default_rng(self.random_seed)
2258
Kevin Cheng550ccc52021-03-03 11:21:43 -08002259 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002260
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002261 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
2262 default_test_rank_range = range(1, 5)
Matthew Haddone86fd342021-09-07 16:12:21 +01002263 if not shapeFilter:
2264 shapeFilter = [None]
2265
2266 # Generate the lists of arguments
2267 rmin, rmax = op["rank"]
2268 if rankFilter is not None:
2269 cleanRankFilter = []
2270 # Ensure rankFilter values are allowed by operator
2271 for rank in rankFilter:
2272 if rank >= rmin and rank <= rmax:
2273 cleanRankFilter.append(rank)
2274 rankFilter = cleanRankFilter
2275 elif rankFilter is None and shapeFilter[0] is None:
2276 cleanRankFilter = []
2277 # Ensure default behaviour is bounded by default range or by operator, whichever is smaller.
2278 rankRange = range(rmin, rmax + 1)
2279 for rank in rankRange:
2280 if rank >= min(default_test_rank_range) and rank <= max(default_test_rank_range):
2281 cleanRankFilter.append(rank)
2282 rankFilter = cleanRankFilter
2283 else:
2284 rankFilter = range(rmin, rmax + 1)
2285
2286 dtypes = op["types"]
2287 if dtypeFilter is not None:
2288 cleanDtypeFilter = []
2289 # Ensure filtered dtypes are allowed by operator
2290 for dtype in dtypeFilter:
2291 if dtype in dtypes:
2292 cleanDtypeFilter.append(dtype)
2293 dtypeFilter = cleanDtypeFilter
2294 else:
2295 dtypeFilter = dtypes
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002296
Eric Kunzee5e26762020-10-13 16:11:07 -07002297 # Test list consists of a tuple of:
2298 # (opName, testNameStr, dtype, errorName, shapeList, argumentsList)
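# e.g. a positive entry might look like (hypothetical shape and name):
#   ("add", "add_13x21x3_i32", DType.INT32, None, [[13, 21, 3], [13, 21, 3]], [])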
2299 testList = []
2300
Matthew Haddon74567092021-07-16 15:38:20 +01002301 # Positive test loop
2302 if testType in ['positive', 'both']:
Matthew Haddone86fd342021-09-07 16:12:21 +01002303 for r in rankFilter:
Kevin Cheng1533b852021-09-01 12:51:58 -07002304 if opName.startswith("conv3d"):
2305 assert r == 5, "conv3d test must have input rank == 5"
Matthew Haddone86fd342021-09-07 16:12:21 +01002306 for t in dtypeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01002307 # Create the placeholder and const tensors
2308 for shape in shapeFilter:
2309 # A None shape chooses a random shape of a given rank
Eric Kunzee5e26762020-10-13 16:11:07 -07002310
Matthew Haddon74567092021-07-16 15:38:20 +01002311 # Filter out by rank
2312 if shape is not None and len(shape) != r:
2313 continue
Matthew Haddon74567092021-07-16 15:38:20 +01002314 self.setTargetShape(shape)
2315 shapeList = tgen_fcn(self, op, r)
Eric Kunzee5e26762020-10-13 16:11:07 -07002316
Matthew Haddon74567092021-07-16 15:38:20 +01002317 shapeStr = self.shapeStr(shapeList[0])
2318 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07002319
Matthew Haddon74567092021-07-16 15:38:20 +01002320 # Argument lists consist of (str, []) tuples: a string representation and the build function argument list
2321 argList = []
2322 if agen_fcn:
2323 argList = agen_fcn(self, opName, shapeList, t)
Eric Kunzee5e26762020-10-13 16:11:07 -07002324 else:
Matthew Haddon74567092021-07-16 15:38:20 +01002325 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07002326
Matthew Haddon74567092021-07-16 15:38:20 +01002327 for argStr, args in argList:
2328 if argStr:
2329 testStr = "{}_{}_{}_{}".format(
2330 opName, shapeStr, typeStr, argStr
2331 )
2332 else:
2333 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
2334
Matthew Haddone86fd342021-09-07 16:12:21 +01002335 testList.append((opName, testStr, t, None, shapeList, args))
Matthew Haddon74567092021-07-16 15:38:20 +01002336
Matthew Haddonb724efc2021-08-25 16:40:29 +01002337 # Remove tests which are expected to fail but don't correlate to an ERROR_IF statement
2338 if "invalid_test_validators" in op:
2339 invalid_test_validators = op["invalid_test_validators"]
2340 clean_testList = []
2341 for test in testList:
2342 for validator_fcn in invalid_test_validators:
2343 remove_test = False
Matthew Haddone86fd342021-09-07 16:12:21 +01002344 if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5]):
Matthew Haddonb724efc2021-08-25 16:40:29 +01002345 remove_test = True
2346 if not remove_test:
2347 clean_testList.append(test)
2348 testList = clean_testList
2349
Matthew Haddone86fd342021-09-07 16:12:21 +01002350 # Store the original filters so they can be reused if required
2351 base_rankFilter = rankFilter
2352 base_dtypeFilter = dtypeFilter
2353 base_shapeFilter = shapeFilter
Matthew Haddon74567092021-07-16 15:38:20 +01002354 # Reset RNG so both positive and negative tests are reproducible
2355 self.resetRNG()
Matthew Haddone86fd342021-09-07 16:12:21 +01002356
Matthew Haddon74567092021-07-16 15:38:20 +01002357 # Negative test loop
Matthew Haddone86fd342021-09-07 16:12:21 +01002358 if testType in ['negative', 'both'] and "error_if_validators" in op:
2359 error_if_validators = op["error_if_validators"]
2360 for validator in error_if_validators:
2361 validator_info = validator()
2362 error_name = validator_info['error_name']
2363 error_arguments = validator_info['param_reqs']
2364
2365 # Set parameters as required
2366 if error_arguments['rank'] is not None:
2367 rmin, rmax = error_arguments['rank']
2368 rankFilter = range(rmin, rmax + 1)
2369 else:
2370 rankFilter = base_rankFilter
2371 if error_arguments['dtype'] is not None:
2372 dtypeFilter = error_arguments['dtype']
2373 else:
2374 dtypeFilter = base_dtypeFilter
2375 if error_arguments['shape'] is not None:
2376 shapes = error_arguments['shape']
2377 else:
2378 shapes = base_shapeFilter[:2] # Reduce number of shapes to keep test numbers small
2379
2380 for r in rankFilter:
2381 for t in dtypeFilter:
2382 # Create the placeholder and const tensors
2383 for shape in shapes:
2384 # A None shape chooses a random shape of a given rank
2385 # Filter out by rank
2386 if shape is not None and len(shape) != r:
2387 continue
2388 self.setTargetShape(shape)
2389 shapeList = tgen_fcn(self, op, r, error_name)
2390 shapeStr = self.shapeStr(shapeList[0])
2391 typeStr = self.typeStr(t)
2392 # Argument lists consist of (str, []) tuples: a string representation and the build function argument list
2393 argList = []
2394 if agen_fcn:
2395 argList = agen_fcn(self, opName, shapeList, t, error_name)
2396 else:
2397 argList = [("", [])]
2398 for argStr, args in argList:
2399 if argStr:
2400 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
2401 opName, error_name, shapeStr, typeStr, argStr
2402 )
2403 else:
2404 testStr = "{}_ERRORIF_{}_{}_{}".format(opName, error_name, shapeStr, typeStr)
2405 testList.append((opName, testStr, t, error_name, shapeList, args))
Eric Kunzee5e26762020-10-13 16:11:07 -07002406
2407 return testList
2408
Matthew Haddone86fd342021-09-07 16:12:21 +01002409
2410 def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
Eric Kunzee5e26762020-10-13 16:11:07 -07002411 try:
2412 op = self.TOSA_OP_LIST[opName]
2413 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002414 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002415
2416 # Create a serializer
2417 self.createSerializer(opName, testStr)
2418
Kevin Cheng550ccc52021-03-03 11:21:43 -08002419 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002420 if "error_if_validators" in op:
2421 error_if_validators = op["error_if_validators"]
2422 else:
2423 error_if_validators = None
2424
Kevin Cheng550ccc52021-03-03 11:21:43 -08002425 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07002426 num_operands = pCount + cCount
2427
2428 if isinstance(dtype_or_dtypeList, list):
2429 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07002430 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002431 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07002432 else:
2433 dtypeList = [dtype_or_dtypeList] * (num_operands)
2434
Kevin Cheng93a16282021-08-31 16:14:03 -07002435 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002436 assert (
2437 len(shapeList) == num_operands
2438 ), "shapeList length {} must match number of operands {}".format(
2439 len(shapeList), num_operands
2440 )
2441 assert (
2442 len(dtypeList) == num_operands
2443 ), "dtypeList length {} must match number of operands {}".format(
2444 len(dtypeList), num_operands
2445 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002446
2447 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002448 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002449 except KeyError:
2450 qgen = None
2451
2452 # Build the random tensor operands and the test
2453 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08002454
Jeremy Johnsonef509a42021-09-07 13:59:47 +01002455 if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32:
2456 # Make sure the operation does not cause value saturation - where
2457 # the number wraps due to limited number of bits to store the answer
2458 assert (
2459 pCount == 2 and cCount == 0
2460 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
2461
2462 placeholders = []
2463 add = (op["op"] == Op.ADD)
2464 a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
2465 b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
2466 if add:
2467 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
2468 else:
2469 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
2470
2471 # Work out the saturation limits
2472 max_i32 = (1 << 31)-1
2473 min_i32 = -(1 << 31)
2474 max_arr = np.full(shapeList[1], max_i32)
2475 min_arr = np.full(shapeList[1], min_i32)
2476
2477 # Find how much values exceed the maximum/minimums
2478 sat_max_arr = np.maximum(res_arr - max_arr, 0)
2479 sat_min_arr = np.minimum(res_arr - min_arr, 0)
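# Worked example (illustrative): if a = 2**31 - 10 and b = 100, the true sum
# 2**31 + 90 exceeds max_i32 by 91, so sat_max_arr = 91 and b is clipped
# below to 100 - 91 = 9, giving a + b = 2**31 - 1, exactly at the limit.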
2480
2481 if not add:
2482 # Swap saturation values and negate values as we need to perform opposite operations
2483 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
2484
2485 # Create new array of unsaturated values by clipping values as needed
2486 b_unsat_arr = b_arr
2487 if (sat_max_arr != 0).any():
2488 # Clip values that cause saturation
2489 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
2490 # Reduce axes in unsaturated tensor to match original tensor
2491 for axis, dim in enumerate(b_arr.shape):
2492 if dim != b_unsat_arr.shape[axis]:
2493 assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
2494 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
2495
2496 if (sat_min_arr != 0).any():
2497 # Clip values that cause saturation
2498 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
2499 # Reduce axes in unsaturated tensor to match original tensor
2500 for axis, dim in enumerate(b_arr.shape):
2501 if dim != b_unsat_arr.shape[axis]:
2502 assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
2503 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
2504
2505 placeholders.append(
2506 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
2507 )
2508 placeholders.append(
2509 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
2510 )
2511
2512 tens.extend(placeholders)
2513 elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
2514 # Force value of operand[1] to be within [0, num_bits]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002515 assert (
2516 pCount == 2 and cCount == 0
2517 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
Kevin Chengaee1fac2020-11-11 13:54:06 -08002518
2519 placeholders = []
2520 for idx, shape in enumerate(shapeList[:]):
2521 if idx == 1:
Kevin Cheng989cb052021-04-28 16:29:44 -07002522 if dtypeList[idx] == DType.INT8:
Kevin Chengaee1fac2020-11-11 13:54:06 -08002523 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07002524 elif dtypeList[idx] == DType.INT16:
Kevin Chengaee1fac2020-11-11 13:54:06 -08002525 arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07002526 elif dtypeList[idx] == DType.INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -08002527 arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
2528 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002529 raise Exception("OpArithmeticRightShift: invalid input dtype")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002530 else:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002531 arr = self.getRandTensor(shape, dtypeList[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -07002532 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
Kevin Chengaee1fac2020-11-11 13:54:06 -08002533
2534 tens.extend(placeholders)
Matthew Haddona44ac5e2021-07-27 16:31:16 +01002535 elif op["op"] == Op.SELECT:
2536 # Set datatype of condition tensor to boolean
2537 dtypeList[0] = DType.BOOL
2538 tens.extend(
2539 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
2540 )
2541 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Matthew Haddon459443c2021-08-23 16:43:13 +01002542 elif op["op"] == Op.INTDIV:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002543 assert (
2544 pCount == 2 and cCount == 0
Matthew Haddon459443c2021-08-23 16:43:13 +01002545 ), "Op.INTDIV must have 2 placeholders, 0 consts"
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002546
2547 placeholders = []
2548
Matthew Haddon459443c2021-08-23 16:43:13 +01002549 # Two invalid cases for Op.INTDIV:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002550 # 1. divisor == 0
Kevin Cheng47315e12021-05-13 17:41:28 -07002551 # 2. dividend == -(1<<31) and divisor == -1
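# Case 2 is invalid because -(1 << 31) / -1 = 1 << 31, which overflows
# int32; both cases are avoided below by rejection-sampling fresh tensors.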
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002552 while True:
2553 dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
2554 divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])
2555
2556 if (divisor_arr == 0).any():
2557 continue
2558
Kevin Cheng47315e12021-05-13 17:41:28 -07002559 if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002560 continue
2561
2562 break
2563
2564 placeholders.append(
2565 self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
2566 )
2567 placeholders.append(
2568 self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
2569 )
2570
2571 tens.extend(placeholders)
2572 elif op["op"] == Op.MUL:
2573 assert (
2574 pCount == 2 and cCount == 0
2575 ), "Op.MUL must have 2 placeholders, 0 consts"
2576
2577 if dtypeList[0] == DType.FLOAT:
2578 tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
2579 else:
2580 placeholders = []
2581
2582 # Make sure the multiply result stays within int32 range
2583 shift = testArgs[0]
2584 if dtypeList[0] == DType.INT8:
2585 num_bits = 8
2586 elif dtypeList[0] == DType.INT16:
2587 num_bits = 16
2588 elif dtypeList[0] == DType.INT32:
2589 num_bits = 32
2590 else:
2591 raise Exception("OpMul: invalid input dtype")
2592
2593 for idx, shape in enumerate(shapeList[:]):
2594 low = -(2 ** (num_bits - 1))
2595 high = (2 ** (num_bits - 1)) - 1
2596
2597 a_arr = np.int32(
2598 self.rng.integers(low=low, high=high, size=shapeList[0])
2599 )
2600 b_arr = np.int32(
2601 self.rng.integers(low=low, high=high, size=shapeList[1])
2602 )
2603
2604 i = 0
2605 while True:
2606
2607 a_arr_64 = a_arr.astype(np.int64)
2608 b_arr_64 = b_arr.astype(np.int64)
2609
2610 if shift > 0:
2611 rounding = 1 << (shift - 1)
2612 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
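# e.g. with shift = 2: 7 * 9 = 63, rounding = 1 << 1 = 2, and
# (63 + 2) >> 2 = 16, which matches round(63 / 4) with round-half-up.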
2613 else:
2614 result_arr = a_arr_64 * b_arr_64
2615
2616 if (result_arr > -(2 ** 31)).all() and (
2617 result_arr <= ((2 ** 31) - 1)
2618 ).all():
2619 break
2620
2621 i = i + 1
2622 a_arr = a_arr // 2
2623 b_arr = b_arr // 2
2624
2625 placeholders.append(
2626 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
2627 )
2628 placeholders.append(
2629 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
2630 )
2631
2632 tens.extend(placeholders)
Matthew Haddon818ab902021-07-27 09:12:49 +01002633 elif op["op"] == Op.CONCAT:
2634 count = len(shapeList) - self.args.num_const_inputs_concat
2635 if count < 1:
2636 count = 1
2637 if self.args.num_const_inputs_concat == 0:
2638 count = len(shapeList)
2639
2640 shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
2641 tens.extend(
2642 self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
2643 )
2644 tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08002645 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002646 tens.extend(
2647 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
2648 )
2649 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Eric Kunzee5e26762020-10-13 16:11:07 -07002650
2651 if qgen is not None:
Les Bell30e46802021-07-23 09:43:31 +01002652 qinfo = qgen(self, op, dtype_or_dtypeList)
Eric Kunzee5e26762020-10-13 16:11:07 -07002653 else:
2654 qinfo = None
2655
2656 try:
Matthew Haddone86fd342021-09-07 16:12:21 +01002657 if error_if_validators is None:
2658 if qinfo is not None:
2659 resultName = build_fcn(self, op["op"], *tens, *testArgs, qinfo)
2660 else:
2661 resultName = build_fcn(self, op["op"], *tens, *testArgs)
Eric Kunzee5e26762020-10-13 16:11:07 -07002662 else:
Matthew Haddone86fd342021-09-07 16:12:21 +01002663 if qinfo is not None:
2664 resultName = build_fcn(self, op["op"], *tens, *testArgs, qinfo, error_if_validators, error_name)
2665 else:
2666 resultName = build_fcn(self, op["op"], *tens, *testArgs, error_if_validators, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002667 except TypeError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002668 print(
2669 "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
2670 build_fcn, tens, testArgs
2671 )
2672 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002673 raise e
2674
Matthew Haddone86fd342021-09-07 16:12:21 +01002675 if resultName is None:
2676 print("Invalid ERROR_IF tests created")
2677
Eric Kunzee5e26762020-10-13 16:11:07 -07002678 # Save the serialized test
Kevin Cheng550ccc52021-03-03 11:21:43 -08002679 self.serialize("test")
Eric Kunzee5e26762020-10-13 16:11:07 -07002680
2681 def createDynamicOpLists(self):
2682
2683 # Dynamically create op lists for convolutions with a list of kernel sizes
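# e.g. the "conv2d_TEMPLATE" entry is cloned into concrete ops such as
# "conv2d_3x3", each with its 'filter' field set to the kernel size and
# its 'template' flag cleared.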
Kevin Cheng1533b852021-09-01 12:51:58 -07002684 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002685
Kevin Cheng1533b852021-09-01 12:51:58 -07002686 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002687 testName = "conv2d_{}x{}".format(k[0], k[1])
2688 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
2689 self.TOSA_OP_LIST[testName]["filter"] = k
2690 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002691
Kevin Cheng550ccc52021-03-03 11:21:43 -08002692 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
2693 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2694 "depthwise_conv2d_TEMPLATE"
2695 ].copy()
2696 self.TOSA_OP_LIST[testName]["filter"] = k
2697 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002698
Kevin Cheng550ccc52021-03-03 11:21:43 -08002699 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
2700 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2701 "transpose_conv2d_TEMPLATE"
2702 ].copy()
2703 self.TOSA_OP_LIST[testName]["filter"] = k
2704 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002705
Kevin Cheng1533b852021-09-01 12:51:58 -07002706 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
2707 for k in KERNELS_3D:
2708 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
2709 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
2710 self.TOSA_OP_LIST[testName]["filter"] = k
2711 self.TOSA_OP_LIST[testName]["template"] = False
2712
Eric Kunzee5e26762020-10-13 16:11:07 -07002713 # Delete any templates after having created any dynamic ops
2714 # This is a two-pass operation because it's bad practice to delete
2715 # keys from dictionaries while iterating
2716 keyList = []
2717 for k in self.TOSA_OP_LIST:
2718 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002719 if self.TOSA_OP_LIST[k]["template"] is True:
Eric Kunzee5e26762020-10-13 16:11:07 -07002720 keyList.append(k)
2721 continue
2722 except KeyError:
2723 pass
2724
2725 for k in keyList:
2726 del self.TOSA_OP_LIST[k]
2727
2728 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002729 """Fill in default fields for ops if they aren't already specified.
2730 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07002731 for op in self.TOSA_OP_LIST:
2732
2733 # Required fields
2734 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002735 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002736 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002737 raise Exception(
2738 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
2739 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002740
2741 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002742 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002743 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002744 raise Exception(
2745 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
2746 op
2747 )
2748 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002749
2750 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002751 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002752 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002753 raise Exception(
2754 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
2755 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002756
2757 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002758 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002759 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002760 raise Exception(
2761 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
2762 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002763
2764 # Put in default rank range, if missing
2765 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002766 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002767 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002768 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002769
2770 # Tensor operator list
2771 # 'op': op name
2772 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08002773 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
2774 # if not specified, defaults to DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002775 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
2776 # 'types': array of datatypes to be tested
    TYPE_FP = [DType.FLOAT]

    TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
    TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

    TYPE_BOOL = [DType.BOOL]
    TYPE_FI32 = [DType.FLOAT, DType.INT32]
    TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
    TYPE_FI16 = [DType.FLOAT, DType.INT16]

    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

    TYPE_CONV = [
        [DType.INT8, DType.INT4, DType.INT32],
        [DType.INT8, DType.INT8, DType.INT32],
        [DType.INT16, DType.INT8, DType.INT48],
        DType.FLOAT,
    ]

    DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)

    TOSA_OP_LIST = {
        # Tensor operators
        "argmax": {
            "op": Op.ARGMAX,
            "operands": (1, 0),
            "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_NARROW_INT_FP,
        },
        "avg_pool2d": {
            "op": Op.AVG_POOL2D,
            "operands": (1, 0),
            "rank": (4, 4),
            "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
            "qgen": TosaQuantGen.qgUnary,
            "types": TYPE_NARROW_INT_FP,
            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
        },
        # Templated operator. Filled in by createDynamicOpLists
        "conv2d_TEMPLATE": {
            "op": Op.CONV2D,
            "operands": (1, 2),
            "rank": (4, 4),
            "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV,
            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
            "template": True,
        },
        # Templated operator. Filled in by createDynamicOpLists
        "conv3d_TEMPLATE": {
            "op": Op.CONV3D,
            "operands": (1, 2),
            "rank": (5, 5),
            "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV,
            "template": True,
        },
        # Templated operator. Filled in by createDynamicOpLists
        "depthwise_conv2d_TEMPLATE": {
            "op": Op.DEPTHWISE_CONV2D,
            "operands": (1, 2),
            "filter": [1, 1],
            "rank": (4, 4),
            "build_fcn": (
                build_depthwise_conv2d,
                TosaTensorGen.tgDepthwiseConv2D,
                TosaArgGen.agConv,
            ),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV,
            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
            "template": True,
        },
        "fully_connected": {
            "op": Op.FULLY_CONNECTED,
            "operands": (1, 2),
            "rank": (2, 2),
            "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV,
        },
        "matmul": {
            "op": Op.MATMUL,
            "operands": (2, 0),
            "rank": (3, 3),
            "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
            "qgen": TosaQuantGen.qgMatmul,
            "types": TYPE_NARROW_INT_FP,
        },
        "max_pool2d": {
            "op": Op.MAX_POOL2D,
            "operands": (1, 0),
            "rank": (4, 4),
            "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
            "types": TYPE_NARROW_INT_FP,
            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
        },
        # Templated operator. Filled in by createDynamicOpLists
        "transpose_conv2d_TEMPLATE": {
            "op": Op.TRANSPOSE_CONV2D,
            "operands": (1, 2),
            "rank": (4, 4),
            "build_fcn": (
                build_transpose_conv2d,
                TosaTensorGen.tgTransposeConv2D,
                TosaArgGen.agTransposeConv2D,
            ),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV,
            "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
            "template": True,
        },
        # Activation functions
        "clamp": {
            "op": Op.CLAMP,
            "operands": (1, 0),
            "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
            "types": TYPE_NARROW_INT_FP,
        },
        "sigmoid": {
            "op": Op.SIGMOID,
            "operands": (1, 0),
            "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "tanh": {
            "op": Op.TANH,
            "operands": (1, 0),
            "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        # Elementwise Binary Operators
        "add": {
            "op": Op.ADD,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "arithmetic_right_shift": {
            "op": Op.ARITHMETIC_RIGHT_SHIFT,
            "operands": (2, 0),
            "build_fcn": (
                build_arithmetic_right_shift,
                TosaTensorGen.tgBroadcastFuzz,
                TosaArgGen.agArithmeticRightShift,
            ),
            "types": TYPE_INT,
        },
        "bitwise_and": {
            "op": Op.BITWISE_AND,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
        },
        "bitwise_or": {
            "op": Op.BITWISE_OR,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
        },
        "bitwise_xor": {
            "op": Op.BITWISE_XOR,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
        },
        "intdiv": {
            "op": Op.INTDIV,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": [DType.INT32],
        },
        "logical_and": {
            "op": Op.LOGICAL_AND,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_BOOL,
        },
        "logical_left_shift": {
            "op": Op.LOGICAL_LEFT_SHIFT,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
        },
        "logical_right_shift": {
            "op": Op.LOGICAL_RIGHT_SHIFT,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
        },
        "logical_or": {
            "op": Op.LOGICAL_OR,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_BOOL,
        },
        "logical_xor": {
            "op": Op.LOGICAL_XOR,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_BOOL,
        },
        "maximum": {
            "op": Op.MAXIMUM,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "minimum": {
            "op": Op.MINIMUM,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "mul": {
            "op": Op.MUL,
            "operands": (2, 0),
            "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
            "types": TYPE_INT_FP,
        },
        "pow": {
            "op": Op.POW,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "sub": {
            "op": Op.SUB,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "table": {
            "op": Op.TABLE,
            # Use the automatic generation functions to create the input array
            # but create the table tensor in the build function, as it may be
            # a different type from the input
            "operands": (1, 0),
            "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
            "types": [DType.INT8, DType.INT16],
        },
        # Elementwise Unary operators
        "abs": {
            "op": Op.ABS,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FI32,
        },
        "bitwise_not": {
            "op": Op.BITWISE_NOT,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_INT,
        },
        "ceil": {
            "op": Op.CEIL,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "clz": {
            "op": Op.CLZ,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": [DType.INT32],
        },
        "exp": {
            "op": Op.EXP,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "floor": {
            "op": Op.FLOOR,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "log": {
            "op": Op.LOG,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "logical_not": {
            "op": Op.LOGICAL_NOT,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_BOOL,
        },
        "negate": {
            "op": Op.NEGATE,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "qgen": TosaQuantGen.qgUnary,
            "types": TYPE_INT_FP,
        },
        "reciprocal": {
            "op": Op.RECIPROCAL,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "rsqrt": {
            "op": Op.RSQRT,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        # Elementwise Ternary operators
        "select": {
            "op": Op.SELECT,
            "operands": (3, 0),
            "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FIB,
        },
        # Comparison operators
        "equal": {
            "op": Op.EQUAL,
            "operands": (2, 0),
            "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "greater_equal": {
            "op": Op.GREATER_EQUAL,
            "operands": (2, 0),
            "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "greater": {
            "op": Op.GREATER,
            "operands": (2, 0),
            "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        # Reduction operators
        "reduce_all": {
            "op": Op.REDUCE_ALL,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_BOOL,
        },
        "reduce_any": {
            "op": Op.REDUCE_ANY,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_BOOL,
        },
        "reduce_max": {
            "op": Op.REDUCE_MAX,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_INT_FP,
        },
        "reduce_min": {
            "op": Op.REDUCE_MIN,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_INT_FP,
        },
        "reduce_product": {
            "op": Op.REDUCE_PRODUCT,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_FP,
        },
        "reduce_sum": {
            "op": Op.REDUCE_SUM,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_FI32,
        },
        # Data layout operators
        "concat": {
            "op": Op.CONCAT,
            "operands": (2, 0),
            "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
            "types": TYPE_FIB,
        },
        "pad": {
            "op": Op.PAD,
            "operands": (1, 0),
            "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
            "qgen": TosaQuantGen.qgPad,
            "types": TYPE_FIB,
        },
        "reshape": {
            "op": Op.RESHAPE,
            "operands": (1, 0),
            "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
            "types": TYPE_FIB,
        },
        "reverse": {
            "op": Op.REVERSE,
            "operands": (1, 0),
            "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_FIB,
        },
        "slice": {
            "op": Op.SLICE,
            "operands": (1, 0),
            "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
            "types": TYPE_FIB,
        },
        "tile": {
            "op": Op.TILE,
            "operands": (1, 0),
            "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
            "types": TYPE_FIB,
        },
        "transpose": {
            "op": Op.TRANSPOSE,
            "operands": (1, 0),
            "rank": (1, 4),
            "build_fcn": (
                build_transpose,
                TosaTensorGen.tgBasic,
                TosaArgGen.agTranspose,
            ),
            "types": TYPE_FIB,
        },
        # Data nodes
        "const": {
            "op": Op.CONST,
            "operands": (0, 1),
            "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
            "types": TYPE_FIB,
        },
        "identity": {
            "op": Op.IDENTITY,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FIB,
        },
        # Scatter/Gather
        "gather": {
            "op": Op.GATHER,
            # Only specify 'values' tensor here. 'indices' is generated in op building stage
            "operands": (1, 0),
            "rank": (3, 3),
            "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
            "types": TYPE_INT_FP,
        },
        "scatter": {
            "op": Op.SCATTER,
            # Only specify 'values_in' tensor here.
            # 'indices' and 'input' are generated in op building stage
            "operands": (2, 0),
            "rank": (3, 3),
            "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
            "types": TYPE_INT_FP,
        },
        # Image operations
        "resize": {
            "op": Op.RESIZE,
            "operands": (1, 0),
            "rank": (4, 4),
            "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
            "types": [DType.INT8, DType.INT16, DType.FLOAT],
            "invalid_test_validators": (
                TosaInvalidValidator.ivWrongDataTypeOrModeResize,
                TosaInvalidValidator.ivBadStride,
            ),
            "error_if_validators": (
                TosaErrorValidator.evMaxDimExceeded,
                TosaErrorValidator.evStrideSmallerEqualZero,
                TosaErrorValidator.evStrideLargerDimension,
                TosaErrorValidator.evStrideLargerEqualMax,
                TosaErrorValidator.evOffsetSmallerEqualMin,
                TosaErrorValidator.evOffsetLargerEqualMax,
                TosaErrorValidator.evShiftNotZero,
                TosaErrorValidator.evShiftSmallerOne,
                TosaErrorValidator.evShiftLargerEleven,
            ),
        },
        # Type conversion
        "cast": {
            "op": Op.CAST,
            "operands": (1, 0),
            "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
            "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
        },
        "rescale": {
            "op": Op.RESCALE,
            "operands": (1, 0),
            "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
            "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
        },
        # Custom
        # Not implemented.
        # Control flow operators
        # Two variants of cond_if: one generates one of two constant tensors (no
        # inputs to the basic blocks, one output) and the other either adds or
        # subtracts two tensors (two inputs to the basic blocks, one output)
        "cond_if_const": {
            "op": Op.COND_IF,
            "operands": (0, 2),
            "build_fcn": (
                build_cond_if_const,
                TosaTensorGen.tgBasic,
                TosaArgGen.agCondIf,
            ),
            "types": [DType.BOOL],
        },
        "cond_if_binary": {
            "op": Op.COND_IF,
            "operands": (2, 0),
            "build_fcn": (
                build_cond_if_binary,
                TosaTensorGen.tgBasic,
                TosaArgGen.agCondIf,
            ),
            "types": TYPE_FI32,
        },
        # while_loop
        "while_loop": {
            "op": Op.WHILE_LOOP,
            "operands": (0, 1),
            "build_fcn": (
                build_while_loop,
                TosaTensorGen.tgBasic,
                TosaArgGen.agWhileLoop,
            ),
            "types": [DType.INT32],
        },
    }


class OutputShaper:
3297 # Methods in this class compute the expected output shape and datatype
3298 # for common classes of operations
3299 def __init__(self):
3300 pass
3301
3302 # These methods return arguments that can be used for
3303 # creating a new output tensor
3304 @staticmethod
3305 def binaryBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003306 assert len(a.shape) == len(b.shape)
3307 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003308
3309 shape = []
3310 for i in range(len(a.shape)):
3311 if a.shape[i] == 1:
3312 shape.append(b.shape[i])
3313 else:
3314 shape.append(a.shape[i])
3315
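        # e.g. (hypothetical shapes) a: [1, 4, 5] with b: [3, 4, 1]:
        # dims where a is 1 take b's size, giving an output of [3, 4, 5]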
        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def binaryNonBroadcastOp(ser, a, b):
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        shape = []
        for i in range(len(a.shape)):
            assert a.shape[i] == b.shape[i]
            shape.append(a.shape[i])

        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def unaryOp(ser, a):
        return ser.addOutput(a.shape, a.dtype)

    @staticmethod
    def selectOp(ser, cond, a, b):
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
        assert a.dtype == b.dtype

        shape = []
        for i in range(len(a.shape)):
            shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))

        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def binaryComparisonOp(ser, a, b):
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        # Do broadcast
        shape = []
        for i in range(len(a.shape)):
            if a.shape[i] == 1:
                shape.append(b.shape[i])
            else:
                shape.append(a.shape[i])

        # Force the output type to bool
        return ser.addOutput(shape, DType.BOOL)

    @staticmethod
    def reduceOp(ser, a, axis):

        shape = a.shape.copy()

        shape[axis] = 1

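        # e.g. shape [2, 3, 4] reduced over axis=1 -> [2, 1, 4]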
        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def argmaxOp(ser, a, axis):
        shape = a.shape.copy()
        del shape[axis]
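        # e.g. shape [2, 3, 4] with axis=1 -> [2, 4]; indices are INT32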
        return ser.addOutput(shape, DType.INT32)

    @staticmethod
    def conv2dOp(ser, ifm, filter, strides, padding, dilations):

        # IFM: NHWC
        # Filter: OHWI
        # OFM: NHWC

        if len(padding) == 2:
            # Expand padding to 4 parameters in the case of transpose_conv2d
            # From H,W to T,B,L,R
            padding = [padding[0], padding[0], padding[1], padding[1]]

        h = (
            ifm.shape[1]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

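        # The width follows the same formula. Worked example with hypothetical
        # numbers: ifm W=16, filter W=3, dilation=1, pad left/right=1, stride=2:
        #   w = (16 - 3 - (3 - 1) * 0 + 1 + 1) // 2 + 1 = 8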
        w = (
            ifm.shape[2]
            - filter.shape[2]
            - (filter.shape[2] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

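        # The output dtype is the accumulator type for the input dtype
        # (INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT)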
        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        return ser.addOutput(ofm_shape, out_dtype)

    @staticmethod
    def conv3dOp(ser, ifm, filter, strides, padding, dilations):

        # IFM: NDHWC
        # Filter: ODHWI
        # OFM: NDHWC

        d = (
            ifm.shape[1]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        h = (
            ifm.shape[2]
            - filter.shape[2]
            - (filter.shape[2] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        w = (
            ifm.shape[3]
            - filter.shape[3]
            - (filter.shape[3] - 1) * (dilations[2] - 1)
            + padding[4]
            + padding[5]
        ) // strides[2] + 1

        ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        return ser.addOutput(ofm_shape, out_dtype)

    @staticmethod
    def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
        # IFM: NHWC
        # Filter: HWCM
        # OFM: NHW C*M
        h = (
            ifm.shape[1]
            - filter.shape[0]
            - (filter.shape[0] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        w = (
            ifm.shape[2]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
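        # Output channels are C * M, e.g. (hypothetical) C=8 input channels
        # with depth multiplier M=2 gives 16 output channels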

        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        return ser.addOutput(ofm_shape, out_dtype)

    @staticmethod
    def pool2dOp(ser, ifm, kernel, stride, pad):
        # input: NHWC
        h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
        w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
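        # Worked example with hypothetical numbers: ifm H=8, pad=(0, 0),
        # stride=2, kernel=2:  h = (8 + 0 + 0 + 2 - 2) // 2 = 4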

        ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
        return ser.addOutput(ofm_shape, ifm.dtype)

    @staticmethod
    def fullyConnectedOp(ser, input, filter):
        # input: N, IC
        # filter: OC, IC
        # output: N, OC

        output_shape = [input.shape[0], filter.shape[0]]
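        # e.g. (hypothetical) input [8, 32] with filter [16, 32] -> output [8, 16]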

        if input.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif input.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif input.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(input.dtype))

        return ser.addOutput(output_shape, out_dtype)

    @staticmethod
    def matmulOp(ser, a, b):
        # a: N, H, C
        # b: N, C, W
        # out: N, H, W

        output_shape = [a.shape[0], a.shape[1], b.shape[2]]
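        # e.g. (hypothetical) a: [2, 5, 7] with b: [2, 7, 3] -> output [2, 5, 3]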

        if a.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif a.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif a.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))

        return ser.addOutput(output_shape, out_dtype)

    @staticmethod
    def concatOp(ser, axis, *a):
        input1 = a[0]
        remaining_inputs = a[1:]

        output_shape = input1.shape.copy()

        output_shape[axis] = input1.shape[axis]

        for tensor in remaining_inputs:
            output_shape[axis] += tensor.shape[axis]
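        # e.g. (hypothetical) two [2, 3, 4] inputs concatenated on axis=1
        # -> output [2, 6, 4]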

        return ser.addOutput(output_shape, input1.dtype)

    @staticmethod
    def padOp(ser, a, padding):

        output_shape = a.shape.copy()

        for i in range(len(output_shape)):
            output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
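        # e.g. (hypothetical) shape [4, 4] with padding [[1, 1], [2, 2]]
        # -> output [6, 8]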

        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def reshapeOp(ser, a, shape):
        output_shape = shape.copy()

        totalElements = 1
        for i in a.shape:
            totalElements *= i

        # If there are any -1 elements, figure out what that dimension must be
        totalOutputElements = 1
        for i in output_shape:
            if i != -1:
                totalOutputElements *= i

        # And fill it in
        for i in range(len(output_shape)):
            if output_shape[i] == -1:
                output_shape[i] = totalElements // totalOutputElements
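        # e.g. (hypothetical) a.shape [2, 3, 4] (24 elements) reshaped to
        # [4, -1]: the -1 becomes 24 // 4 = 6, giving an output of [4, 6]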

        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def sliceOp(ser, a, begin, size):

        output_shape = size.copy()
        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def tileOp(ser, a, multiples):

        output_shape = a.shape.copy()
        assert len(multiples) == len(output_shape)

        for i in range(len(output_shape)):
            output_shape[i] = a.shape[i] * multiples[i]
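        # e.g. (hypothetical) a.shape [2, 3] with multiples [2, 2] -> [4, 6]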

        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def transposeOp(ser, a, perms):
        output_shape = a.shape.copy()
        assert len(perms) == len(output_shape)

        for i in range(len(output_shape)):
            output_shape[i] = a.shape[perms[i]]
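        # e.g. (hypothetical) a.shape [2, 3, 4] with perms [2, 0, 1]
        # -> output [4, 2, 3]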

        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def gatherOp(ser, values, indices):
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

        output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
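        # e.g. (hypothetical) values [2, 10, 4] (N, K, C) with indices [2, 6]
        # (N, W) -> output [2, 6, 4] (N, W, C)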

        return ser.addOutput(output_shape, values.dtype)

    @staticmethod
    def scatterOp(ser, values_in, indices, input):
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

        output_shape = values_in.shape

        return ser.addOutput(output_shape, values_in.dtype)

    @staticmethod
    def tableOp(ser, input, table_dtype):
        # Same shape as the input, but dtype dependent on table dtype
        assert table_dtype == DType.INT16 or table_dtype == DType.INT8
        output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
        return ser.addOutput(input.shape, output_dtype)

    @staticmethod
    def resizeOp(
        ser,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        error_name=None,
    ):

        output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]

        return ser.addOutput(output_dims, output_dtype)

    @staticmethod
    def typeConversionOp(ser, val, out_dtype):
        return ser.addOutput(val.shape, out_dtype)

    @staticmethod
    def transposeConv2DOp(ser, ifm, output_shape):
        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        return ser.addOutput(output_shape, out_dtype)