blob: 07dc7e5c5f047b8f2b9e7a44fa29cf5e127831ed [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18import numpy as np
19import argparse
20import sys
21import re
22import os
23import subprocess
24import shlex
25import json
26import glob
27import math
28import queue
29import threading
30import traceback
31import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010032import itertools
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
Kevin Chengacb550f2021-06-29 15:32:19 -070035from tosa_ref_run import TosaReturnCode
Eric Kunzee5e26762020-10-13 16:11:07 -070036
Kevin Cheng550ccc52021-03-03 11:21:43 -080037# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
38parent_dir = os.path.dirname(os.path.realpath(__file__))
39sys.path.append(
40 os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
41)
Eric Kunzee5e26762020-10-13 16:11:07 -070042import tosa_serializer as ts
43from tosa_serializer import *
44import tosa
Matthew Haddone86fd342021-09-07 16:12:21 +010045from tosa_error_if import ErrorIf
Eric Kunzee5e26762020-10-13 16:11:07 -070046
47# Convenience variables to the flatc-generated types that should be enums, but aren't
48DType = tosa.DType.DType()
Kevin Cheng550ccc52021-03-03 11:21:43 -080049Op = tosa.Op.Op()
Eric Kunzee5e26762020-10-13 16:11:07 -070050ResizeMode = tosa.ResizeMode.ResizeMode()
51
52class TosaQuantGen:
Kevin Cheng550ccc52021-03-03 11:21:43 -080053 """QuantizedInfo random generator helper functions. Specify with 'qgen': in the operator defintion"""
54
Eric Kunzee5e26762020-10-13 16:11:07 -070055 def __init__(self):
56 pass
57
58 @staticmethod
Les Bell30e46802021-07-23 09:43:31 +010059 def getQinfo(testGen, dtype):
60 if dtype == DType.INT8:
61 return testGen.randInt(-128, 128)
62 if dtype == DType.UINT8:
63 return testGen.randInt(0, 256)
64 return 0
Eric Kunzee5e26762020-10-13 16:11:07 -070065
66 @staticmethod
67 def qgUnary(testGen, op, dtype):
68 qinfo = ts.TosaSerializerQuantInfo()
Kevin Chengacb550f2021-06-29 15:32:19 -070069 qinfo.UnaryQuantInfo(
70 TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
71 )
Eric Kunzee5e26762020-10-13 16:11:07 -070072 return qinfo
73
74 @staticmethod
Les Bell30e46802021-07-23 09:43:31 +010075 def qgConv(testGen, op, dtype_or_dtypeList):
Eric Kunzee5e26762020-10-13 16:11:07 -070076 qinfo = ts.TosaSerializerQuantInfo()
Les Bell30e46802021-07-23 09:43:31 +010077 if isinstance(dtype_or_dtypeList, list):
78 # a list of [input, weights, accumulator] dtypes
79 dtypeList = dtype_or_dtypeList
Eric Kunzee5e26762020-10-13 16:11:07 -070080 else:
Les Bell30e46802021-07-23 09:43:31 +010081 # an int, [input, weights, accumulator] dtypes are the same
82 dtypeList = [dtype_or_dtypeList] * 3
83 input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
84 weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
85 qinfo.ConvQuantInfo(input_zp, weights_zp)
Eric Kunzee5e26762020-10-13 16:11:07 -070086 return qinfo
87
88 @staticmethod
89 def qgMatmul(testGen, op, dtype):
90 qinfo = ts.TosaSerializerQuantInfo()
Kevin Chengacb550f2021-06-29 15:32:19 -070091 qinfo.MatMulQuantInfo(
92 TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
93 )
Eric Kunzee5e26762020-10-13 16:11:07 -070094 return qinfo
95
96 @staticmethod
97 def qgPad(testGen, op, dtype):
98 qinfo = ts.TosaSerializerQuantInfo()
Les Bell30e46802021-07-23 09:43:31 +010099 qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700100 return qinfo
101
102 @staticmethod
103 def computeMultiplierAndShift(scaleFp, scale32):
104 # Derived from computeMultiplierAndShiftTosaScale32
105 # Provide a floating-point scaling factor and the scale32 parameter
106 # to compute the multiplier and shift
107
108 if scale32:
109 scaleBits = 31
110 else:
111 scaleBits = 15
112
113 m, shift = math.frexp(scaleFp)
114
115 if scaleFp < 0.0:
116 m = -m
117
118 multiplier = round(m * (1 << scaleBits))
Kevin Cheng550ccc52021-03-03 11:21:43 -0800119 assert multiplier <= (1 << scaleBits)
Eric Kunzee5e26762020-10-13 16:11:07 -0700120
121 if multiplier == (1 << scaleBits):
122 multiplier = multiplier // 2
123 shift = shift + 1
124
125 shift = (-shift) + scaleBits
Matthew Haddonb724efc2021-08-25 16:40:29 +0100126 #print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(scaleFp, scaleBits, m, multiplier, shift))
127
128 # Adjust multiplier such that shift is in allowed value range.
129 if shift == 0:
130 multiplier = multiplier // 4
131 shift = shift + 2
132 elif shift == 1:
133 multiplier = multiplier // 2
134 shift = shift + 1
135 elif shift == 63:
136 multiplier = multiplier * 2
137 shift = shift - 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700138
Kevin Cheng550ccc52021-03-03 11:21:43 -0800139 assert multiplier <= (1 << scaleBits)
Matthew Haddonb724efc2021-08-25 16:40:29 +0100140 assert shift >= 2 and shift <= 62
Eric Kunzee5e26762020-10-13 16:11:07 -0700141
142 return multiplier, shift
143
144
Kevin Cheng550ccc52021-03-03 11:21:43 -0800145class TosaTensorGen:
146 """Tensor generators create a shape list for the placeholder and const tensor
147 data operands for the operator. The actual random data is generated separately for each test."""
148
Eric Kunzee5e26762020-10-13 16:11:07 -0700149 def __init__(self):
150 pass
151
152 @staticmethod
153 def tgBasic(testGen, opName, rank):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800154 pl, const = opName["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700155 shape = testGen.makeShape(rank)
156
157 shape_list = []
158 for i in range(pl + const):
159 shape_list.append(shape.copy())
160
161 return shape_list
162
163 @staticmethod
Matthew Haddone86fd342021-09-07 16:12:21 +0100164 def tgNHWC(testGen, opName, rank, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800165 pl, const = opName["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700166
Matthew Haddon848efb42021-09-09 12:30:53 +0100167 if error_name != ErrorIf.WrongRank:
168 assert rank == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700169
170 shape = testGen.makeShape(rank)
171
172 # Constrict the batch size?
173 if testGen.args.max_batch_size:
174 shape[0] = (shape[0] % testGen.args.max_batch_size) + 1
175
176 shape_list = []
177 for i in range(pl + const):
178 shape_list.append(shape.copy())
179
180 return shape_list
181
182 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -0800183 def tgScatter(testGen, opName, rank):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800184 pl, const = opName["operands"]
Kevin Cheng77d0f762020-11-24 10:26:32 -0800185
Kevin Cheng550ccc52021-03-03 11:21:43 -0800186 assert pl == 2
187 assert const == 0
188 assert rank == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -0800189
190 values_in_shape = testGen.makeShape(rank)
191
Matthew Haddon4b2881a2021-08-24 14:25:43 +0100192 # ignore max batch size if target shape is set
193 if testGen.args.max_batch_size and not testGen.args.target_shapes:
Kevin Cheng77d0f762020-11-24 10:26:32 -0800194 values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1
195
Kevin Cheng550ccc52021-03-03 11:21:43 -0800196 W = testGen.randInt(
197 testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
198 )
Matthew Haddon4b2881a2021-08-24 14:25:43 +0100199 # Constrict W if one dimension is too large to keep tensor size reasonable
200 if max(values_in_shape) > 5000:
201 W = testGen.randInt(0, 16)
202
Kevin Cheng77d0f762020-11-24 10:26:32 -0800203 input_shape = [values_in_shape[0], W, values_in_shape[2]]
204
205 shape_list = []
206 shape_list.append(values_in_shape.copy())
207 shape_list.append(input_shape.copy())
208
209 return shape_list
210
211 @staticmethod
Eric Kunzee5e26762020-10-13 16:11:07 -0700212 def tgBroadcastFuzz(testGen, op, rank):
213 shape = testGen.makeShape(rank)
214
Kevin Cheng550ccc52021-03-03 11:21:43 -0800215 pl, const = op["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700216
217 shape_list = []
218
219 # Choose one of the inputs to broadcast
220 bcast_idx = testGen.randInt(0, pl + const)
221 for i in range(pl + const):
222 shape_bcast = shape.copy()
223
224 # If the chosen input, pick a random index to broadcast
225 if i == bcast_idx:
226 fuzz_idx = testGen.randInt(0, rank)
227 shape_bcast[fuzz_idx] = 1
228
229 shape_list.append(shape_bcast)
230
231 return shape_list
232
233 @staticmethod
234 def tgConv2D(testGen, op, rank):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800235 pl, const = op["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700236
Kevin Cheng550ccc52021-03-03 11:21:43 -0800237 assert rank == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700238
239 # IFM dimensions are NHWC
240 ifm_shape = testGen.makeShape(rank)
241
242 # Constrict the batch size?
243 if testGen.args.max_batch_size:
244 ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
245
246 # Get the filter height/width from the operator parameters
Kevin Cheng550ccc52021-03-03 11:21:43 -0800247 filter_hw = op["filter"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700248
249 # Generate a random OFM depth
250 ofm_depth = testGen.makeShape(1)[0]
251
252 # The filter dimensions are OHWI
253 filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
254
255 # The bias is OC
256 bias_shape = np.asarray([ofm_depth])
257
258 return [ifm_shape, filter_shape, bias_shape]
259
260 @staticmethod
Kevin Cheng1533b852021-09-01 12:51:58 -0700261 def tgConv3D(testGen, op, rank):
262 pl, const = op["operands"]
263
264 assert rank == 5
265
266 # IFM dimensions are NDHWC
267 ifm_shape = testGen.makeShape(rank)
268
269 # Constrict the batch size?
270 if testGen.args.max_batch_size:
271 ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
272
273 # Get the filter depth/height/width from the operator parameters
274 filter_dhw = op["filter"]
275
276 # Generate a random OFM channel
277 ofm_channel = testGen.makeShape(1)[0]
278
279 # The filter dimensions are ODHWI
280 filter_shape = np.asarray(
281 [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
282 )
283
284 # The bias is OC
285 bias_shape = np.asarray([ofm_channel])
286
287 return [ifm_shape, filter_shape, bias_shape]
288
289 @staticmethod
Eric Kunzee5e26762020-10-13 16:11:07 -0700290 def tgTransposeConv2D(testGen, op, rank):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800291 pl, const = op["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700292
Kevin Cheng550ccc52021-03-03 11:21:43 -0800293 assert rank == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700294
295 # IFM dimensions are NHWC
296 ifm_shape = testGen.makeShape(rank)
297
298 # Constrict the batch size?
299 if testGen.args.max_batch_size:
300 ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
301
302 # Get the filter height/width from the operator parameters
Kevin Cheng550ccc52021-03-03 11:21:43 -0800303 filter_hw = op["filter"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700304
305 # Generate a random OFM depth
306 ofm_depth = testGen.makeShape(1)[0]
307
308 # The filter dimensions are OHWI
309 filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
310
Kevin Cheng989cb052021-04-28 16:29:44 -0700311 # The bias is OC
312 bias_shape = np.asarray([ofm_depth])
313
314 return [ifm_shape, filter_shape, bias_shape]
Eric Kunzee5e26762020-10-13 16:11:07 -0700315
316 @staticmethod
317 def tgDepthwiseConv2D(testGen, op, rank):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800318 pl, const = op["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700319
Kevin Cheng550ccc52021-03-03 11:21:43 -0800320 assert rank == 4
321 assert pl == 1 and const == 2
Eric Kunzee5e26762020-10-13 16:11:07 -0700322
323 # IFM dimensions are NHWC
324 ifm_shape = testGen.makeShape(rank)
325
326 # Constrict the batch size?
327 if testGen.args.max_batch_size:
328 ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
329
330 # Get the filter height/width from the operator parameters
331 # Filter is KH, HW, C, M
Kevin Cheng550ccc52021-03-03 11:21:43 -0800332 filter_hw = op["filter"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700333
334 # Generate a random OFM depth, but don't let it get too big because
335 # the output depth is M * C
Kevin Cheng550ccc52021-03-03 11:21:43 -0800336 filter_m = (
337 testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
338 ) + 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700339
340 # The filter dimensions are HWCM
341 filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])
342
343 # The bias is M * C
344 bias_shape = np.asarray([ifm_shape[3] * filter_m])
345
346 return [ifm_shape, filter_shape, bias_shape]
347
348 @staticmethod
349 def tgFullyConnected(testGen, op, rank):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800350 pl, const = op["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700351
Kevin Cheng550ccc52021-03-03 11:21:43 -0800352 assert rank == 2
Eric Kunzee5e26762020-10-13 16:11:07 -0700353
354 input_shape = testGen.makeShape(rank)
Kevin Chengacb550f2021-06-29 15:32:19 -0700355 filter_oc = testGen.rng.integers(
356 low=testGen.args.tensor_shape_range[0],
357 high=testGen.args.tensor_shape_range[1],
358 size=1,
359 )[0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700360 filter_shape = np.asarray([filter_oc, input_shape[1]])
361
362 bias_shape = np.asarray([filter_oc])
363
364 return [input_shape, filter_shape, bias_shape]
365
366 @staticmethod
367 def tgMatmul(testGen, op, rank):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800368 pl, const = op["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700369
Kevin Cheng2d60f002021-06-09 14:18:32 -0700370 assert rank == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -0800371 assert pl == 2 and const == 0
Eric Kunzee5e26762020-10-13 16:11:07 -0700372
373 a_shape = testGen.makeShape(rank)
Matthew Haddon68e7aee2021-08-16 11:20:25 +0100374 # Get a random number for b_oc even if target shape is defined
375 b_oc = np.int32(
376 testGen.rng.integers(
377 low=testGen.args.tensor_shape_range[0],
378 high=testGen.args.tensor_shape_range[1],
379 size=1,
380 )
381 )[0]
382 # If N or H is large let b_oc be 1 to reduce output tensor size
383 if max(a_shape) > 1000:
384 b_oc = 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700385
Matthew Haddon68e7aee2021-08-16 11:20:25 +0100386 b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
Eric Kunzee5e26762020-10-13 16:11:07 -0700387 return [a_shape, b_shape]
388
Matthew Haddon818ab902021-07-27 09:12:49 +0100389 @staticmethod
390 def tgConcat(testGen, opName, rank):
391 pl, const = opName["operands"]
392 shape = testGen.makeShape(rank)
393
394 # Create extra tensors to concat.
395 # Take into account value of pl when getting maximum number of concats
396 num_tensors = testGen.randInt(0, 4)
397 shape_list = []
398 for i in range(pl + const + num_tensors):
399 shape_list.append(shape.copy())
400
401 return shape_list
402
403 @staticmethod
404 def tgConcatConstInput(testGen, shapeList, axis):
405 # Split concat shape along axis to allow for multiple const inputs
406 # without making too many large tensors
407 shape = shapeList[0]
408 if len(shapeList) == 2 or shape[axis] < len(shapeList):
409 return shapeList
410
411 new_shapeList = [shape.copy()]
412 length_on_axis = shape[axis]
413 remaining_length = length_on_axis
Kevin Cheng93a16282021-08-31 16:14:03 -0700414 for i in range(len(shapeList) - 2):
Matthew Haddon818ab902021-07-27 09:12:49 +0100415 # Calculate split on axis and remaining value
416 split_shape_val = int(shape[axis] / 2)
417 remaining_length = remaining_length - split_shape_val
418
419 # Append new shape, and set remaining shape
420 shape[axis] = split_shape_val
421 new_shapeList.append(shape.copy())
422 shape[axis] = remaining_length
423 if i == len(shapeList) - 3:
424 new_shapeList.append(shape.copy())
425
426 return new_shapeList
427
428
Eric Kunzee5e26762020-10-13 16:11:07 -0700429class TosaArgGen:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800430 """Argument generators create exhaustive or random lists of attributes for operators that take
431 attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
432 tuples where the descriptive_name is appended to the test name and the arglist is expanded
433 as arguments to the operator build function."""
434
Eric Kunzee5e26762020-10-13 16:11:07 -0700435 def __init__(self):
436 pass
437
438 @staticmethod
439 def agNone(testGen, opName, shapeList, dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800440 """A trivial argument generator for operators that don't take any
441 non-tensor arguments"""
442 return [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -0700443
444 @staticmethod
445 def agAxis(testGen, opName, shapeList, dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800446 """Build the axis argument for operators that take a single axis"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700447 axes = []
448
449 shape = shapeList[0]
450
451 for a in range(0, len(shape)):
Matthew Haddon43e37192021-07-09 14:13:02 +0100452 axes.append(("axis{}".format(a), [a]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700453 return axes
454
455 @staticmethod
Les Bell7aa69f42021-09-20 10:44:07 +0100456 def agConv(testGen, opName, shapeList, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -0700457 arg_list = []
458
459 ifm_shape = shapeList[0]
460 filter_shape = shapeList[1]
Les Bell7aa69f42021-09-20 10:44:07 +0100461 # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
462 k = [int(x) for x in opName.split("_")[-1].split("x")]
Eric Kunzee5e26762020-10-13 16:11:07 -0700463
Les Bell7aa69f42021-09-20 10:44:07 +0100464 # Check the rank
465 rank = 5 if opName.startswith("conv3d") else 4
466 assert len(ifm_shape) == rank
467 assert len(filter_shape) == rank
Eric Kunzee5e26762020-10-13 16:11:07 -0700468
Les Bell7aa69f42021-09-20 10:44:07 +0100469 # kernel rank omits batch and channels
470 k_rank = rank - 2
Eric Kunzee5e26762020-10-13 16:11:07 -0700471
Les Bell7aa69f42021-09-20 10:44:07 +0100472 # Generate comprehensive argument lists
473 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
474 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
475 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
476 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
477 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
478 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700479
Les Bell7aa69f42021-09-20 10:44:07 +0100480 # add some oversize argument values
481 if max(ifm_shape) < 64:
482 bigPadding = 9
483 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
484 bigStride = 8
485 strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
486 bigDilation = 7
487 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})
Les Bellf414b3c2021-09-06 11:29:46 +0100488
489 # There are too many parameter combinations, so generate them sparsely
Les Bell7aa69f42021-09-20 10:44:07 +0100490 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
491 sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
492 if sparsity < 13:
493 sparsity = 1
494 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
495 sparsity += 1
Les Bellf414b3c2021-09-06 11:29:46 +0100496 n = 0
Les Bell7aa69f42021-09-20 10:44:07 +0100497 for s in sorted(list(strides)):
498 for p in sorted(list(paddings)):
499 for d in sorted(list(dilations)):
500 if (n % sparsity == 0
501 # padding must not exceed the kernel size ?
502 # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
503 # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
504 # the padded shape must exceed the kernel size
505 and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1]
506 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
507 # the padded shape must exceed the dilation
508 and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1]
509 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
510 ):
Les Bellf414b3c2021-09-06 11:29:46 +0100511 arg_list.append(
512 (
513 "st{}_pad{}_dilat{}".format(
514 "".join([str(x) for x in s]),
515 "".join([str(x) for x in p]),
516 "".join([str(x) for x in d]),
517 ),
518 [s, p, d],
519 )
520 )
521 n += 1
522
Kevin Cheng1533b852021-09-01 12:51:58 -0700523 return arg_list
524
525 @staticmethod
Eric Kunzee5e26762020-10-13 16:11:07 -0700526 def agTransposeConv2D(testGen, opName, shapeList, dtype):
527 arg_list = []
528
529 ifm_shape = shapeList[0]
530 filter_shape = shapeList[1]
531
532 # Must be rank 4
Kevin Cheng550ccc52021-03-03 11:21:43 -0800533 assert len(ifm_shape) == 4
534 assert len(filter_shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700535
Les Bell7aa69f42021-09-20 10:44:07 +0100536 # Generate comprehensive argument lists
537 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
538 paddings = {x for x in itertools.product(*([p_vals] * 2))}
539 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
540 strides = {x for x in itertools.product(*([s_vals] * 2))}
541 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
542 dilations = {x for x in itertools.product(*([d_vals] * 2))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700543
Les Bell7aa69f42021-09-20 10:44:07 +0100544 # add some oversize argument values
545 if max(ifm_shape) < 64:
546 bigPadding = 9
547 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
548 bigStride = 8
549 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
550 bigDilation = 7
551 dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})
Eric Kunzee5e26762020-10-13 16:11:07 -0700552
Les Bell7aa69f42021-09-20 10:44:07 +0100553 # There are too many parameter combinations, so generate them sparsely
554 # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
555 sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
556 if sparsity < 13:
557 sparsity = 1
558 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
559 sparsity += 1
560 n = 0
561 for s in sorted(list(strides)):
562 for p in sorted(list(paddings)):
563 for d in sorted(list(dilations)):
564 if n % sparsity == 0:
565 # Determine the output shape
566 oh = (
567 ifm_shape[1]
568 - filter_shape[1]
569 - (filter_shape[1] - 1) * (d[0] - 1)
570 + 2 * p[0]
571 ) // s[0] + 1
572 ow = (
573 ifm_shape[2]
574 - filter_shape[2]
575 - (filter_shape[2] - 1) * (d[1] - 1)
576 + 2 * p[1]
577 ) // s[1] + 1
578 os = [ifm_shape[0], oh, ow, filter_shape[0]]
579 arg_list.append(
580 (
581 "st{}_pad{}_dilat{}_os{}".format(
582 "".join([str(x) for x in s]),
583 "".join([str(x) for x in p]),
584 "".join([str(x) for x in d]),
585 "x".join([str(x) for x in os]),
586 ),
587 [s, p, d, os],
588 )
Kevin Cheng550ccc52021-03-03 11:21:43 -0800589 )
Les Bell7aa69f42021-09-20 10:44:07 +0100590 n += 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700591
592 return arg_list
593
594 @staticmethod
595 def agPad(testGen, opName, shapeList, dtype):
596 arg_list = []
597 rank = len(shapeList[0])
598
Les Bell7ffccce2021-07-28 15:37:02 +0100599 # Exhaustively test combinations of padding on each side of each dimension
600 # - the range of padding values is defined by pad_min and pad_max
601 # - for padding >9, the name format needs to be more distinctive
602 pad_min, pad_max = 0, 1
603 pad_values = [x for x in range(pad_min, pad_max + 1)]
604 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
605 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
Eric Kunzee5e26762020-10-13 16:11:07 -0700606
Les Bell7ffccce2021-07-28 15:37:02 +0100607 for paddings in shape_pad_values:
608 name = "pad"
609 for r in range(rank):
610 before, after = paddings[r]
611 name = f"{name}{before}{after}"
612 arg_list.append((name, [np.array(paddings)]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700613
614 return arg_list
615
616 @staticmethod
617 def agPooling(testGen, opName, shapeList, dtype):
618 arg_list = []
619
620 shape = shapeList[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -0800621 assert len(shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700622
Les Bell7aa69f42021-09-20 10:44:07 +0100623 # Generate comprehensive argument lists
624 p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
625 paddings = {x for x in itertools.product(*([p_vals] * 4))}
626 s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
627 strides = {x for x in itertools.product(*([s_vals] * 2))}
628 k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 2)]
629 kernels = {x for x in itertools.product(*([k_vals] * 2))}
Eric Kunzee5e26762020-10-13 16:11:07 -0700630
Les Bell7aa69f42021-09-20 10:44:07 +0100631 # add some oversize argument values
632 bigStride = 7
633 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
634 bigKernel = 6
635 kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
636 if max(shape) < 64:
637 # padding must be less than the kernel size
638 bigPadding = bigKernel - 1
639 paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})
Eric Kunzee5e26762020-10-13 16:11:07 -0700640
Les Bell7aa69f42021-09-20 10:44:07 +0100641 # There are too many parameter combinations, so generate them sparsely
642 sparsity = len(paddings) * len(strides) * len(kernels) // 500 + 1
643 n = 0
644 for s in sorted(list(strides)):
645 for p in sorted(list(paddings)):
646 for k in sorted(list(kernels)):
647 if (n % sparsity == 0
648 # padding must not exceed the kernel size
649 and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
650 # the padded shape must exceed the kernel size
651 and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
652 ):
653 arg_list.append(
654 (
655 "st{}_kern{}_pad{}".format(
656 "".join([str(x) for x in s]),
657 "".join([str(x) for x in k]),
658 "".join([str(x) for x in p]),
659 ),
660 [s, p, k],
661 )
Kevin Cheng550ccc52021-03-03 11:21:43 -0800662 )
Les Bell7aa69f42021-09-20 10:44:07 +0100663 n += 1
664
Eric Kunzee5e26762020-10-13 16:11:07 -0700665 return arg_list
666
667 @staticmethod
668 def agCast(testGen, opName, shapeList, inDtype):
669 arg_list = []
670
671 # Enumerate the output types here
672 if inDtype == DType.INT8:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800673 dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700674 elif inDtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800675 dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700676 elif inDtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800677 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700678 elif inDtype == DType.BOOL:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800679 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700680 elif inDtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800681 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700682 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800683 raise Exception("Unexpected input dtype: {}".format(inDtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700684
685 for dtype in dtypeList:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800686 arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700687
688 return arg_list
689
690 @staticmethod
691 def agRescale(testGen, opName, shapeList, inDtype):
692 arg_list = []
693
694 # Enumerate the output types here
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100695 for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
696 if inDtype == DType.UINT8 and dtype != DType.INT8:
697 # The only output dtype for UINT8 is INT8, skip all other combinations
698 continue
699 if inDtype != DType.INT8 and dtype == DType.UINT8:
700 # The only input dtype for UINT8 is INT8, skip all other combinations
701 continue
702
Kevin Cheng550ccc52021-03-03 11:21:43 -0800703 for scale32 in [False, True]:
704 for double_round in [False, True]:
705 for per_channel in [False, True]:
Eric Kunzee5e26762020-10-13 16:11:07 -0700706
707 if inDtype == DType.INT48 and scale32:
708 # Illegal condition. Must be scale32=False
709 continue
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100710 if double_round and not scale32:
711 # Illegal condition. ERROR_IF(!scale32 && double_round)
712 continue
Eric Kunzee5e26762020-10-13 16:11:07 -0700713
Kevin Cheng550ccc52021-03-03 11:21:43 -0800714 arg_list.append(
715 (
716 "out{}_sc{}_dr{}_pc{}".format(
717 DTypeNames[dtype],
718 int(scale32),
719 int(double_round),
720 int(per_channel),
721 ),
722 [dtype, scale32, double_round, per_channel],
723 )
724 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700725
726 return arg_list
727
Kevin Chengaee1fac2020-11-11 13:54:06 -0800728 @staticmethod
729 def agMul(testGen, opName, shapeList, dtype):
730 arg_list = []
731
732 if dtype is DType.INT32:
733 for p in range(testGen.args.num_rand_permutations):
734
735 shift = testGen.randInt(0, 32)
736
Kevin Cheng550ccc52021-03-03 11:21:43 -0800737 arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800738 else:
Matthew Haddon43e37192021-07-09 14:13:02 +0100739 arg_list.append(("perm0_shift0", [0]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800740
741 return arg_list
742
743 @staticmethod
744 def agArithmeticRightShift(testGen, opName, shapeList, dtype):
745 arg_list = []
746
Kevin Cheng550ccc52021-03-03 11:21:43 -0800747 arg_list.append(("roundTrue", [True]))
748 arg_list.append(("roundFalse", [False]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800749
750 return arg_list
751
Eric Kunzee5e26762020-10-13 16:11:07 -0700752 # Helper function for reshape. Gets some factors of a larger number.
753 @staticmethod
754 def getFactors(val, start=1):
755 factors = []
756
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100757 for i in range(start, int(np.sqrt(val)) + 1):
Eric Kunzee5e26762020-10-13 16:11:07 -0700758 if (val % i) == 0:
759 factors.append(i)
760
761 return factors
762
763 @staticmethod
764 def agReshape(testGen, opName, shapeList, dtype):
765 arg_list = []
766
767 origShape = shapeList[0]
768
769 totalElements = 1
770 for s in origShape:
771 totalElements *= s
772
773 # This code is NOT fast. Fortunately, the numbers are fairly small.
774 factors = TosaArgGen.getFactors(totalElements)
775
776 for p in range(testGen.args.num_rand_permutations):
Matthew Haddon5fc4e682021-07-07 11:28:29 +0100777 newRank = testGen.randInt(1, 7)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800778 if len(factors) < newRank:
Eric Kunzee5e26762020-10-13 16:11:07 -0700779 continue
780
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100781 found = True
782 # escape_counter breaks while loop if it continues on for too long
783 escape_counter = 0
784 while found:
785 newShape = []
786 # Generate newShape ensuring it isn't a duplicate
787 remainingElements = totalElements
788 shuffledFactors = testGen.rng.permutation(factors)
Matthew Haddon5fc4e682021-07-07 11:28:29 +0100789 for i in range(1, newRank):
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100790 # pick rank-1 factors
791 newShape.append(shuffledFactors[0])
792 remainingElements = remainingElements // shuffledFactors[0]
793 shuffledFactors = testGen.rng.permutation(
794 TosaArgGen.getFactors(remainingElements)
795 )
796 newShape.append(remainingElements)
Eric Kunzee5e26762020-10-13 16:11:07 -0700797
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100798 # Toss in a -1 sometimes
799 minusOne = testGen.randInt(0, newRank * 4)
800 if minusOne < newRank:
801 newShape[minusOne] = -1
Eric Kunzee5e26762020-10-13 16:11:07 -0700802
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100803 # Check for duplicates
804 found = False
805 for name, other_shape in arg_list:
806 if other_shape[0] == newShape:
807 found = True
808 break
809
810 escape_counter += 1
811 if escape_counter >= 100:
812 break
813
814 if not found:
815 arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700816
817 return arg_list
818
Eric Kunzee5e26762020-10-13 16:11:07 -0700819 @staticmethod
820 def agTranspose(testGen, opName, shapeList, dtype):
821 arg_list = []
822
823 ifm_shape = shapeList[0]
824
Jeremy Johnsona6185572021-06-21 15:55:35 +0100825 # Get all permutations
826 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
Eric Kunzee5e26762020-10-13 16:11:07 -0700827
Jeremy Johnsona6185572021-06-21 15:55:35 +0100828 # Limit to possible permutations from shape dimension or argument setting
829 limit = min(len(permutations), testGen.args.num_rand_permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700830
Jeremy Johnsona6185572021-06-21 15:55:35 +0100831 # Get random permutation generator that uses all permutations
832 random_permutations = testGen.rng.permutation(permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700833
Jeremy Johnsona6185572021-06-21 15:55:35 +0100834 # Create list of required amount of permutations
Kevin Chengacb550f2021-06-29 15:32:19 -0700835 arg_list = [
836 ("perm{}".format(p), [random_permutations[p].tolist()])
837 for p in range(limit)
838 ]
Eric Kunzee5e26762020-10-13 16:11:07 -0700839 return arg_list
840
841 @staticmethod
842 def agSlice(testGen, opName, shapeList, dtype):
843 arg_list = []
844
845 ifm_shape = shapeList[0]
846 rank = len(ifm_shape)
847
848 for p in range(testGen.args.num_rand_permutations):
849 begin = []
850 size = []
851
Kevin Cheng550ccc52021-03-03 11:21:43 -0800852 valid = True
Eric Kunzee5e26762020-10-13 16:11:07 -0700853
854 for i in range(rank):
855 if ifm_shape[i] > 1:
856 begin.append(testGen.randInt(0, ifm_shape[i]))
857 size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))
858
859 # Invalid slice size?
860 if size[i] == 0:
861 valid = False
862 else:
863 begin.append(0)
864 size.append(1)
865
866 if valid:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800867 arg_list.append(("perm{}".format(p), [begin, size]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700868 return arg_list
869
870 @staticmethod
871 def agTile(testGen, opName, shapeList, dtype):
872 arg_list = []
873
874 ifm_shape = shapeList[0]
875 rank = len(ifm_shape)
876
877 for p in range(testGen.args.num_rand_permutations):
878
879 # Pick a few random, but small multiple values
880 # because otherwise this has a tendency to generate
881 # enormous tensors
882 multiples = []
883 for i in range(rank):
Matthew Haddon82ad4d62021-08-20 15:02:39 +0100884 if ifm_shape[i] > 1000:
885 # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
886 multiples.append(1)
887 elif max(ifm_shape) > 1000:
888 multiples.append(2)
889 else:
890 multiples.append(testGen.randInt(1, 4))
Kevin Cheng550ccc52021-03-03 11:21:43 -0800891 arg_list.append(("perm{}".format(p), [multiples]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700892
893 return arg_list
894
895 @staticmethod
Matthew Haddone86fd342021-09-07 16:12:21 +0100896 def agResize(testGen, opName, shapeList, dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700897 arg_list = []
898
899 ifm_shape = shapeList[0]
Matthew Haddon848efb42021-09-09 12:30:53 +0100900 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Eric Kunzee5e26762020-10-13 16:11:07 -0700901
902 # Exclude illegal {mode, type} configurations. Pick legal output types
Matthew Haddon848efb42021-09-09 12:30:53 +0100903 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
Les Bell33d837e2021-08-10 08:34:43 +0100904 outputDTypeList = [DType.INT8]
Matthew Haddon848efb42021-09-09 12:30:53 +0100905 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800906 outputDTypeList = [DType.INT16]
Matthew Haddon848efb42021-09-09 12:30:53 +0100907 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
Les Bell33d837e2021-08-10 08:34:43 +0100908 outputDTypeList = [DType.INT32]
Matthew Haddon848efb42021-09-09 12:30:53 +0100909 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800910 outputDTypeList = [DType.INT48]
Kevin Cheng77d0f762020-11-24 10:26:32 -0800911 elif dtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800912 outputDTypeList = [DType.FLOAT]
Matthew Haddon848efb42021-09-09 12:30:53 +0100913 elif error_name == ErrorIf.WrongInputType:
914 # If an incorrect input type is used then we set a 'correct'
915 # output type to avoid other errors
916 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700917 else:
918 continue
919
920 for outputDType in outputDTypeList:
921 for perm in range(testGen.args.num_rand_permutations):
Eric Kunzee5e26762020-10-13 16:11:07 -0700922 # Randomly generate legal output dimensions and shift
923 # and then compute the stride and offset based on them
Matthew Haddone86fd342021-09-07 16:12:21 +0100924 # A output_dim of 1 will cause offset to exceed allowed range
925 # so minimum value 2 produced below
926 output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
927 while ((float(ifm_shape[1]) / float(output_dims[0])) >= 16):
928 output_dims[0] += 1
929 while ((float(ifm_shape[2]) / float(output_dims[1])) >= 16):
930 output_dims[1] += 1
931
Kevin Cheng77d0f762020-11-24 10:26:32 -0800932 in_center_h = (ifm_shape[1] - 1) / 2.0
933 in_center_w = (ifm_shape[2] - 1) / 2.0
934 out_center_h = (output_dims[0] - 1) / 2.0
935 out_center_w = (output_dims[1] - 1) / 2.0
Eric Kunzee5e26762020-10-13 16:11:07 -0700936
Kevin Cheng77d0f762020-11-24 10:26:32 -0800937 fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
938 fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
939 fp_offset_y = in_center_h - fp_stride_y * out_center_h
940 fp_offset_x = in_center_w - fp_stride_x * out_center_w
Eric Kunzee5e26762020-10-13 16:11:07 -0700941
Kevin Cheng77d0f762020-11-24 10:26:32 -0800942 if outputDType == DType.FLOAT:
943 shift = 0
944 stride = [0, 0]
945 offset = [0, 0]
Kevin Cheng550ccc52021-03-03 11:21:43 -0800946 stride_fp = [fp_stride_y, fp_stride_x]
947 offset_fp = [fp_offset_y, fp_offset_x]
Matthew Haddone86fd342021-09-07 16:12:21 +0100948
949 if error_name is not None:
Matthew Haddon848efb42021-09-09 12:30:53 +0100950 shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
Matthew Haddone86fd342021-09-07 16:12:21 +0100951 testGen,
952 error_name,
Matthew Haddon848efb42021-09-09 12:30:53 +0100953 mode,
954 dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +0100955 shapeList,
956 outputDType,
957 shift,
958 stride,
959 stride_fp,
960 offset,
961 offset_fp
962 )
Matthew Haddon848efb42021-09-09 12:30:53 +0100963 else:
964 outputDTypeNew = outputDType
Matthew Haddone86fd342021-09-07 16:12:21 +0100965
Kevin Cheng550ccc52021-03-03 11:21:43 -0800966 arg_list.append(
967 (
968 "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
Matthew Haddon848efb42021-09-09 12:30:53 +0100969 "N" if mode == ResizeMode.NEAREST else "B",
Kevin Cheng550ccc52021-03-03 11:21:43 -0800970 output_dims[0],
971 output_dims[1],
Matthew Haddon848efb42021-09-09 12:30:53 +0100972 testGen.typeStr(outputDTypeNew),
Kevin Cheng550ccc52021-03-03 11:21:43 -0800973 stride_fp[0],
974 stride_fp[1],
975 offset_fp[0],
976 offset_fp[1],
977 ),
978 [
Matthew Haddon848efb42021-09-09 12:30:53 +0100979 mode,
Kevin Cheng550ccc52021-03-03 11:21:43 -0800980 stride,
981 offset,
982 shift,
983 stride_fp,
984 offset_fp,
985 output_dims,
986 dtype,
Matthew Haddon848efb42021-09-09 12:30:53 +0100987 outputDTypeNew,
Kevin Cheng550ccc52021-03-03 11:21:43 -0800988 ],
989 )
990 )
Kevin Cheng77d0f762020-11-24 10:26:32 -0800991 else:
992 shift = 11
993 unit = float(1 << shift)
994 stride_y = int(round(fp_stride_y * unit))
995 stride_x = int(round(fp_stride_x * unit))
996 offset_y = int(round(fp_offset_y * unit))
997 offset_x = int(round(fp_offset_x * unit))
Eric Kunzee5e26762020-10-13 16:11:07 -0700998
Kevin Cheng550ccc52021-03-03 11:21:43 -0800999 while (
Matthew Haddone86fd342021-09-07 16:12:21 +01001000 stride_y >= (16 << shift)
1001 or stride_x >= (16 << shift)
1002 or offset_y >= (16 << shift)
1003 or offset_x >= (16 << shift)
1004 or offset_y <= (-16 << shift)
1005 or offset_x <= (-16 << shift)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001006 ):
Kevin Cheng77d0f762020-11-24 10:26:32 -08001007 shift = shift - 1
1008 unit = float(1 << shift)
1009 stride_y = int(round(fp_stride_y * unit))
1010 stride_x = int(round(fp_stride_x * unit))
1011 offset_y = int(round(fp_offset_y * unit))
1012 offset_x = int(round(fp_offset_x * unit))
Eric Kunzee5e26762020-10-13 16:11:07 -07001013
Kevin Cheng550ccc52021-03-03 11:21:43 -08001014 stride = [stride_y, stride_x]
1015 offset = [offset_y, offset_x]
Kevin Cheng77d0f762020-11-24 10:26:32 -08001016
1017 stride_fp = [0.0, 0.0]
1018 offset_fp = [0.0, 0.0]
1019
Matthew Haddone86fd342021-09-07 16:12:21 +01001020 if error_name is not None:
Matthew Haddon848efb42021-09-09 12:30:53 +01001021 shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
Matthew Haddone86fd342021-09-07 16:12:21 +01001022 testGen,
1023 error_name,
Matthew Haddon848efb42021-09-09 12:30:53 +01001024 mode,
1025 dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01001026 shapeList,
1027 outputDType,
1028 shift,
1029 stride,
1030 stride_fp,
1031 offset,
1032 offset_fp
1033 )
Matthew Haddon848efb42021-09-09 12:30:53 +01001034 else:
1035 outputDTypeNew = outputDType
Matthew Haddone86fd342021-09-07 16:12:21 +01001036
Kevin Cheng550ccc52021-03-03 11:21:43 -08001037 arg_list.append(
1038 (
1039 "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
Matthew Haddon848efb42021-09-09 12:30:53 +01001040 "N" if mode == ResizeMode.NEAREST else "B",
Kevin Cheng550ccc52021-03-03 11:21:43 -08001041 shift,
1042 output_dims[0],
1043 output_dims[1],
Matthew Haddon848efb42021-09-09 12:30:53 +01001044 testGen.typeStr(outputDTypeNew),
Kevin Cheng550ccc52021-03-03 11:21:43 -08001045 stride[0],
1046 stride[1],
1047 offset[0],
1048 offset[1],
1049 ),
1050 [
Matthew Haddon848efb42021-09-09 12:30:53 +01001051 mode,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001052 stride,
1053 offset,
1054 shift,
1055 stride_fp,
1056 offset_fp,
1057 output_dims,
1058 dtype,
Matthew Haddon848efb42021-09-09 12:30:53 +01001059 outputDTypeNew,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001060 ],
1061 )
1062 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001063
1064 return arg_list
1065
1066 def agCondIf(testGen, opName, shapeList, dtype):
1067 # CondIf generates the condition values here.
1068 # Convert to tensors in the build function, along with the
1069 # then and else blocks
1070 arg_list = []
1071
1072 for c in [False, True]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001073 arg_list.append(("cond{}".format(int(c)), [c]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001074
1075 return arg_list
1076
1077 def agWhileLoop(testGen, opName, shapeList, dtype):
1078 # While loop: 0 iterations, 1, more than 1
1079 arg_list = []
1080
1081 for iter in [0, 1, 4]:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001082 arg_list.append(("iter{}".format(iter), [iter]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001083
1084 return arg_list
1085
Matthew Haddone86fd342021-09-07 16:12:21 +01001086class TosaErrorIfArgGen:
1087
1088 @staticmethod
Matthew Haddon848efb42021-09-09 12:30:53 +01001089 def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):
Matthew Haddone86fd342021-09-07 16:12:21 +01001090
1091 if outputDType == DType.FLOAT:
1092 if error_name == ErrorIf.StrideSmallerEqualZero:
1093 stride_fp = testGen.rng.random(size=[2]) - 2
1094 elif error_name == ErrorIf.ShiftNotZero:
1095 shift = testGen.rng.integers(1, 5)
1096 elif error_name == ErrorIf.StrideLargerDimension:
1097 shape = shapeList[0]
1098 transform_height = testGen.rng.choice([False, True])
1099 if transform_height:
1100 stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
1101 else:
1102 stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
1103 else:
1104 if error_name == ErrorIf.StrideSmallerEqualZero:
1105 stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
1106 elif error_name == ErrorIf.ShiftSmallerOne:
1107 shift = testGen.rng.integers(-3, 1)
1108 if shift <= 0:
1109 stride = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
1110 offset = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
1111 else:
1112 stride = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
1113 offset = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
1114 elif error_name == ErrorIf.ShiftLargerEleven:
1115 shift = np.int16(testGen.rng.integers(12, 15))
1116 elif error_name == ErrorIf.StrideLargerDimension:
1117 shape = shapeList[0]
1118 transform_height = testGen.rng.choice([False, True])
1119 if transform_height:
1120 stride[0] = shape[1] + testGen.rng.integers(1, 10)
1121 else:
1122 stride[1] = shape[2] + testGen.rng.integers(1, 10)
1123 elif error_name == ErrorIf.StrideLargerEqualMax:
1124 stride = [(16 << shift) + 1, (16 << shift) + 1]
1125 elif error_name == ErrorIf.OffsetLargerEqualMax:
1126 offset = [(16 << shift) + 1, (16 << shift) + 1]
1127 elif error_name == ErrorIf.OffsetSmallerEqualMin:
1128 offset = [(-16 << shift) - 1, (-16 << shift) - 1]
1129
Matthew Haddon848efb42021-09-09 12:30:53 +01001130 if error_name == ErrorIf.WrongOutputType:
1131 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
1132 incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
1133 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
1134 incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
1135 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
1136 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
1137 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
1138 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
1139 elif dtype == DType.FLOAT:
1140 incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
1141 outputDType = testGen.rng.choice(a=incorrect_types)
Matthew Haddone86fd342021-09-07 16:12:21 +01001142
Matthew Haddon848efb42021-09-09 12:30:53 +01001143 return shift, stride, stride_fp, offset, offset_fp, outputDType
1144
1145 @staticmethod
1146 def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
1147 # Mess up input/output tensors for ERROR_IF checks
1148 if error_name == "WrongInputList":
1149 add_input = testGen.rng.choice([True, False])
1150 if add_input:
1151 input_list.append('eiDummyInput')
1152 else:
1153 input_list = input_list[:-1]
1154 if error_name == "WrongOutputList":
1155 add_output = testGen.rng.choice([True, False])
1156 if add_output:
1157 output_list.append('eiDummyOutput')
1158 else:
1159 output_list = []
1160 return input_list, output_list
Matthew Haddone86fd342021-09-07 16:12:21 +01001161
1162class TosaErrorValidator:
1163
Matthew Haddon848efb42021-09-09 12:30:53 +01001164 @staticmethod
1165 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
1166 # Check ERROR_IF statements
1167
1168 for val_fcn in validator_fcns:
1169 val_result = val_fcn(True, **kwargs)
1170
1171 validator_name = val_result['error_name']
1172 error_result = val_result['error_result']
1173 error_reason = val_result['error_reason']
1174
1175 if error_result:
1176 if error_name == validator_name:
1177 serializer.setExpectedReturnCode(2, error_reason)
1178 else:
1179 print(f"Multiple ERROR_IF checks hit \nError required: {error_name}, Error_produced: {validator_name}")
1180 return None # Return None to delete test if wrong ERROR_IF is hit
1181 else:
1182 if error_name == validator_name:
1183 print(f"No ERROR_IF hit for {error_name}")
1184 return None
1185
1186 @staticmethod
1187 def evWrongInputType(check=False, **kwargs):
1188 all_dtypes = (DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
1189
1190 # Find the unsupported input data types
1191 assert 'op' in kwargs
1192 op = kwargs['op']
1193 input_dtypes = op['types']
1194 wrong_input_dtypes = list(set(all_dtypes) - set(input_dtypes))
1195
1196 error_name = ErrorIf.WrongInputType
1197 param_reqs = {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
1198 error_result = False
1199 error_reason = "Input data type not supported for this operator"
1200
1201 if check:
1202 input_dtype = kwargs['input_dtype']
1203 if input_dtype not in input_dtypes:
1204 error_result = True
1205
1206 info_dict = {
1207 "error_name": error_name,
1208 "error_result": error_result,
1209 "error_reason": error_reason,
1210 "param_reqs": param_reqs
1211 }
1212 return info_dict
1213
1214 @staticmethod
1215 def evWrongOutputType(check=False, **kwargs):
1216 error_name = ErrorIf.WrongOutputType
1217 param_reqs = {"rank": None, "dtype": None, "shape": None}
1218 error_result = False
1219 error_reason = "Output data type not supported for this configuration of operator"
1220
1221 if check:
1222 input_dtype = kwargs['input_dtype']
1223 output_dtype = kwargs['output_dtype']
1224 mode = kwargs['mode']
1225
1226 if (
1227 (mode == ResizeMode.NEAREST and input_dtype == DType.INT8 and output_dtype != DType.INT8) or
1228 (mode == ResizeMode.NEAREST and input_dtype == DType.INT16 and output_dtype != DType.INT16) or
1229 (mode == ResizeMode.BILINEAR and input_dtype == DType.INT8 and output_dtype != DType.INT32) or
1230 (mode == ResizeMode.BILINEAR and input_dtype == DType.INT16 and output_dtype != DType.INT48) or
1231 (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
1232 ):
1233 error_result = True
1234
1235 info_dict = {
1236 "error_name": error_name,
1237 "error_result": error_result,
1238 "error_reason": error_reason,
1239 "param_reqs": param_reqs
1240 }
1241 return info_dict
1242
1243 @staticmethod
1244 def evWrongRank(check=False, **kwargs):
1245 all_ranks = (1, 2, 3, 4, 5)
1246
1247 # Make a list of incorrect ranks
1248 assert 'op' in kwargs
1249 op = kwargs['op']
1250 rmin, rmax = op['rank']
1251 rank_range = range(rmin, rmax + 1)
1252 incorrect_ranks = list(set(all_ranks) - set(rank_range))
1253 # Set minimum incorrect rank to 3 to avoid index error
1254 if op['op'] == Op.RESIZE:
1255 incorrect_ranks = [3, 5]
1256
1257 error_name = ErrorIf.WrongRank
1258 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
1259 error_result = False
1260 error_reason = "Rank not supported for this operator"
1261
1262 if check:
1263 input_shape = kwargs['input_shape']
1264 if op['op'] == Op.RESIZE and len(input_shape.shape) != 4:
1265 error_result = True
1266
1267 info_dict = {
1268 "error_name": error_name,
1269 "error_result": error_result,
1270 "error_reason": error_reason,
1271 "param_reqs": param_reqs
1272 }
1273 return info_dict
1274
1275 @staticmethod
1276 def evWrongInputList(check=False, **kwargs):
1277 error_name = ErrorIf.WrongInputList
1278 param_reqs = {"rank": None, "dtype": None, "shape": None}
1279 error_result = False
1280 error_reason = "Op input list does not match expected input"
1281
1282 if check:
1283 op = kwargs['op']
1284 input_list = kwargs['input_list']
1285 num_operands = kwargs['num_operands']
1286 if len(input_list) != num_operands:
1287 error_result = True
1288
1289 info_dict = {
1290 "error_name": error_name,
1291 "error_result": error_result,
1292 "error_reason": error_reason,
1293 "param_reqs": param_reqs
1294 }
1295 return info_dict
1296
1297 @staticmethod
1298 def evWrongOutputList(check=False, **kwargs):
1299 error_name = ErrorIf.WrongOutputList
1300 param_reqs = {"rank": None, "dtype": None, "shape": None}
1301 error_result = False
1302 error_reason = "Op output list does not match expected output"
1303
1304 if check:
1305 output_list = kwargs['output_list']
1306 # Note this will be incorrect if an operator returns more than one output
1307 if len(output_list) != 1:
1308 error_result = True
1309
1310 info_dict = {
1311 "error_name": error_name,
1312 "error_result": error_result,
1313 "error_reason": error_reason,
1314 "param_reqs": param_reqs
1315 }
1316 return info_dict
Matthew Haddone86fd342021-09-07 16:12:21 +01001317
1318 @staticmethod
1319 def evMaxDimExceeded(check=False, **kwargs):
1320 error_name = ErrorIf.MaxDimExceeded
Matthew Haddon848efb42021-09-09 12:30:53 +01001321 param_reqs = {
1322 "rank": [4,4],
1323 "dtype": [DType.INT8],
1324 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]]
1325 }
Matthew Haddone86fd342021-09-07 16:12:21 +01001326 error_result = False
1327 error_reason = "At least one maximum dimension is larger than 16384"
1328
1329 if check:
1330 input_shape = kwargs['input_shape'].shape
1331 output_shape = kwargs['output_shape'] # Note this is just (OH, OW)
1332 if ((input_shape[1] > 16384) or
1333 (input_shape[2] > 16384) or
1334 (output_shape[0] > 16384) or
1335 (output_shape[1] > 16384)):
1336 error_result = True
1337
1338 info_dict = {
1339 "error_name": error_name,
1340 "error_result": error_result,
1341 "error_reason": error_reason,
1342 "param_reqs": param_reqs
1343 }
1344 return info_dict
1345
1346 @staticmethod
1347 def evStrideSmallerEqualZero(check=False, **kwargs):
1348 error_name = ErrorIf.StrideSmallerEqualZero
1349 param_reqs = {"rank": None, "dtype": None, "shape": None}
1350 error_result = False
1351 error_reason = "Stride value smaller than or equal zero"
1352
1353 if check:
1354 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001355 output_dtype = kwargs['output_dtype']
1356 if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
1357 stride = kwargs['stride'] # Work around wrong input/output type tests
1358 elif output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001359 stride = kwargs['stride_fp']
Matthew Haddon848efb42021-09-09 12:30:53 +01001360 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
1361 stride = kwargs['stride_fp'] # Work around wrong input/output type tests
Matthew Haddone86fd342021-09-07 16:12:21 +01001362 else:
1363 stride = kwargs['stride']
1364
1365 if min(stride) <= 0:
1366 error_result = True
1367
1368 info_dict = {
1369 "error_name": error_name,
1370 "error_result": error_result,
1371 "error_reason": error_reason,
1372 "param_reqs": param_reqs
1373 }
1374 return info_dict
1375
1376 @staticmethod
1377 def evStrideLargerEqualMax(check=False, **kwargs):
1378 error_name = ErrorIf.StrideLargerEqualMax
1379 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1380 error_result = False
1381 error_reason = "Stride value larger than or equal to maximum value"
1382
1383 if check:
1384 shift = kwargs['shift']
1385 input_dtype = kwargs['input_dtype']
1386 stride = kwargs['stride']
1387 if input_dtype in [DType.INT8, DType.INT16]:
1388 if shift >= 0 and (stride[0] >= (16 << shift) or stride[1] >= (16 << shift)):
1389 error_result = True
1390 elif shift < 0 and (stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)):
1391 error_result = True
1392
1393 info_dict = {
1394 "error_name": error_name,
1395 "error_result": error_result,
1396 "error_reason": error_reason,
1397 "param_reqs": param_reqs
1398 }
1399 return info_dict
1400
1401
1402 @staticmethod
1403 def evStrideLargerDimension(check=False, **kwargs):
1404 error_name = ErrorIf.StrideLargerDimension
1405 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
1406 error_result = False
1407 error_reason = "Stride value larger than or equal to H/W dimension"
1408
1409 if check:
1410 shape = kwargs['input_shape'].shape
1411 input_dtype = kwargs['input_dtype']
1412 stride = kwargs['stride_fp']
1413
1414 if input_dtype == DType.FLOAT and (stride[0] > shape[1]) or (stride[1] > shape[2]):
1415 error_result = True
1416
1417 info_dict = {
1418 "error_name": error_name,
1419 "error_result": error_result,
1420 "error_reason": error_reason,
1421 "param_reqs": param_reqs
1422 }
1423 return info_dict
1424
1425
1426 @staticmethod
1427 def evOffsetSmallerEqualMin(check=False, **kwargs):
1428 error_name = ErrorIf.OffsetSmallerEqualMin
1429 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1430 error_result = False
1431 error_reason = "Offset value smaller than or equal to minimum value"
1432
1433 if check:
1434 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01001435 output_dtype = kwargs['output_dtype']
1436 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001437 offset = kwargs['offset_fp']
1438 else:
1439 offset = kwargs['offset']
1440
1441 if shift >= 0 and (offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)):
1442 error_result = True
1443 elif shift < 0 and (offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)):
1444 error_result = True
1445
1446 info_dict = {
1447 "error_name": error_name,
1448 "error_result": error_result,
1449 "error_reason": error_reason,
1450 "param_reqs": param_reqs
1451 }
1452 return info_dict
1453
1454 @staticmethod
1455 def evOffsetLargerEqualMax(check=False, **kwargs):
1456 error_name = ErrorIf.OffsetLargerEqualMax
1457 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1458 error_result = False
1459 error_reason = "Offset value larger than or equal to maximum value"
1460
1461 if check:
1462 shift = kwargs['shift']
Matthew Haddon848efb42021-09-09 12:30:53 +01001463 output_dtype = kwargs['output_dtype']
1464 if output_dtype == DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001465 offset = kwargs['offset_fp']
1466 else:
1467 offset = kwargs['offset']
1468
1469 if shift >= 0:
1470 if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
1471 error_result = True
1472
1473 if shift >= 0 and (offset[0] >= (16 << shift) or offset[1] >= (16 << shift)):
1474 error_result = True
1475 elif shift < 0 and (offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)):
1476 error_result = True
1477
1478 info_dict = {
1479 "error_name": error_name,
1480 "error_result": error_result,
1481 "error_reason": error_reason,
1482 "param_reqs": param_reqs
1483 }
1484 return info_dict
1485
1486 @staticmethod
1487 def evShiftNotZero(check=False, **kwargs):
1488 error_name = ErrorIf.ShiftNotZero
1489 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
1490 error_result = False
1491 error_reason = "Shift value must be zero for float input"
1492
1493 if check:
1494 shift = kwargs['shift']
1495 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001496 output_dtype = kwargs['output_dtype']
1497 if input_dtype == DType.FLOAT and output_dtype == DType.FLOAT and shift != 0:
Matthew Haddone86fd342021-09-07 16:12:21 +01001498 error_result = True
1499
1500 info_dict = {
1501 "error_name": error_name,
1502 "error_result": error_result,
1503 "error_reason": error_reason,
1504 "param_reqs": param_reqs
1505 }
1506 return info_dict
1507
1508
1509 @staticmethod
1510 def evShiftSmallerOne(check=False, **kwargs):
1511 error_name = ErrorIf.ShiftSmallerOne
1512 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1513 error_result = False
1514 error_reason = "Shift value smaller than one"
1515
1516 if check:
1517 shift = kwargs['shift']
1518 input_dtype = kwargs['input_dtype']
Matthew Haddon848efb42021-09-09 12:30:53 +01001519 output_dtype = kwargs['output_dtype']
1520 if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
Matthew Haddone86fd342021-09-07 16:12:21 +01001521 error_result = True
1522
1523 info_dict = {
1524 "error_name": error_name,
1525 "error_result": error_result,
1526 "error_reason": error_reason,
1527 "param_reqs": param_reqs
1528 }
1529 return info_dict
1530
1531 @staticmethod
1532 def evShiftLargerEleven(check=False, **kwargs):
1533 error_name = ErrorIf.ShiftLargerEleven
1534 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
1535 error_result = False
1536 error_reason = "Shift value larger than eleven"
1537
1538 if check:
1539 shift = kwargs['shift']
1540 if shift > 11:
1541 error_result = True
1542
1543 info_dict = {
1544 "error_name": error_name,
1545 "error_result": error_result,
1546 "error_reason": error_reason,
1547 "param_reqs": param_reqs
1548 }
1549 return info_dict
1550
1551
Matthew Haddonb724efc2021-08-25 16:40:29 +01001552class TosaInvalidValidator:
1553
1554 @staticmethod
1555 def ivWrongDataTypeOrModeResize(**kwargs):
1556 input_dtype = kwargs["input_dtype"]
1557 args = kwargs["args"]
1558 mode = args[0]
1559 stride = args[1]
1560 stride_fp = args[4]
1561 output_dtype = args[8]
1562
1563 if mode == ResizeMode.BILINEAR:
1564 # Invalid output data type / Invalid input datatype
1565 return (
1566 not (input_dtype == DType.INT8 and output_dtype == DType.INT32) or
1567 not (input_dtype == DType.INT16 and output_dtype == DType.INT48) or
1568 not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT) or
1569 (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
1570 )
1571 elif mode == ResizeMode.NEAREST:
1572 # Invalid output data type / Invalid input datatype
1573 return (
1574 (input_dtype != output_dtype) or
1575 (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
1576 )
1577 else:
1578 # Invalid resize mode
1579 return True
1580
1581 @staticmethod
1582 def ivBadStride(**kwargs):
1583 input_dtype = kwargs["input_dtype"]
1584 args = kwargs["args"]
1585 stride_x = args[1][0]
1586 stride_y = args[1][1]
1587 stride_fp_x = args[4][0]
1588 stride_fp_y = args[4][1]
1589
1590 if input_dtype == DType.FLOAT:
1591 if stride_fp_x <= 0 or stride_fp_y <= 0:
1592 # Negative or zero stride
1593 return True
1594 else:
1595 if stride_x <= 0 or stride_y <= 0:
1596 # Negative or zero stride
1597 return True
1598 return False
1599
1600
Matthew Haddonb724efc2021-08-25 16:40:29 +01001601 @staticmethod
1602 def ivHeightWidthSmallerZero(**kwargs):
1603 opName = kwargs['opName']
1604
1605 inputShapes = kwargs['shapeList']
1606 input = inputShapes[0]
1607 if not opName.endswith("pool2d"):
1608 filter = inputShapes[1]
1609
1610 args = kwargs['args']
1611 strides = args[0]
1612 padding = args[1]
1613 dilations = args[2]
1614 if opName.endswith("pool2d"):
1615 kernel = args[2]
1616
1617 if opName.startswith('conv2d'):
1618 h = (
1619 input[1]
1620 - filter[1]
1621 - (filter[1] - 1) * (dilations[0] - 1)
1622 + padding[0]
1623 + padding[1]
1624 ) // strides[0] + 1
1625
1626 w = (
1627 input[2]
1628 - filter[2]
1629 - (filter[2] - 1) * (dilations[1] - 1)
1630 + padding[2]
1631 + padding[3]
1632 ) // strides[1] + 1
1633 elif opName.startswith("depthwise_conv2d"):
1634 h = (
1635 input[1]
1636 - filter[0]
1637 - (filter[0] - 1) * (dilations[0] - 1)
1638 + padding[0]
1639 + padding[1]
1640 ) // strides[0] + 1
1641
1642 w = (
1643 input[2]
1644 - filter[1]
1645 - (filter[1] - 1) * (dilations[1] - 1)
1646 + padding[2]
1647 + padding[3]
1648 ) // strides[1] + 1
1649 elif opName.endswith("pool2d"):
1650 h = (input[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
1651 w = (input[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
1652 else:
1653 assert False, "Unrecognized Op"
1654
1655 if h <= 0 or w <= 0:
1656 # Invalid parameter combination
1657 return True
1658 return False
1659
1660 @staticmethod
1661 def ivNonPositiveOutputShape(**kwargs):
1662 args = kwargs['args']
1663 output_shape = args[3]
1664 if output_shape[1] <= 0 or output_shape[2] <= 0:
1665 # Negative output shape
1666 return True
1667 return False
1668
1669
Kevin Cheng550ccc52021-03-03 11:21:43 -08001670
Eric Kunzee5e26762020-10-13 16:11:07 -07001671class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01001672 # Maximum rank of tensor supported by test generator.
1673 TOSA_TENSOR_MAX_RANK = 6
1674
Eric Kunzee5e26762020-10-13 16:11:07 -07001675 def __init__(self, args):
1676 self.args = args
1677 self.basePath = args.output_dir
1678 self.random_seed = args.random_seed
1679 self.ser = None
1680 self.rng = np.random.default_rng(self.random_seed)
1681 self.createDynamicOpLists()
1682 self.initOpListDefaults()
1683 self.quantGen = TosaQuantGen()
1684 # Force makeShape to do a specific starting shape
1685 self.targetted_shape = None
1686
1687 def createSerializer(self, opName, testPath):
1688 self.testPath = os.path.join(opName, testPath)
1689
1690 fullPath = os.path.join(self.basePath, self.testPath)
1691 os.makedirs(fullPath, exist_ok=True)
1692 self.ser = ts.TosaSerializer(fullPath)
1693
1694 def getSerializer(self):
1695 return self.ser
1696
1697 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001698 with open(
1699 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
1700 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -07001701 fd.write(self.ser.serialize())
1702
Kevin Cheng550ccc52021-03-03 11:21:43 -08001703 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
1704 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -07001705
Matthew Haddon74567092021-07-16 15:38:20 +01001706 def resetRNG(self, seed=None):
1707 if seed == None:
1708 seed = self.random_seed + 1
1709 self.rng = np.random.default_rng(seed)
1710
Eric Kunzee5e26762020-10-13 16:11:07 -07001711 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -07001712 if dtype == DType.BOOL:
1713 np_dt = np.bool
1714 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -07001715 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07001716 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07001717 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07001718 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01001719 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
1720 elif dtype == DType.UINT8:
1721 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07001722 elif dtype == DType.INT16:
1723 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
1724 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001725 return np.int32(
1726 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
1727 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001728 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001729 return np.int64(
1730 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
1731 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001732 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +01001733 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07001734 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001735 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07001736
Kevin Cheng989cb052021-04-28 16:29:44 -07001737 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07001738 placeholders = []
1739
Kevin Cheng989cb052021-04-28 16:29:44 -07001740 assert len(shape_list) == len(dtype_list)
1741
1742 for idx, shape in enumerate(shape_list):
1743 arr = self.getRandTensor(shape, dtype_list[idx])
1744 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001745
1746 return placeholders
1747
Kevin Cheng989cb052021-04-28 16:29:44 -07001748 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -07001749 consts = []
1750
Kevin Cheng989cb052021-04-28 16:29:44 -07001751 assert len(shape_list) == len(dtype_list)
1752
1753 for idx, shape in enumerate(shape_list):
1754 arr = self.getRandTensor(shape, dtype_list[idx])
1755 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001756
1757 return consts
1758
1759 def makeShape(self, rank):
1760 if self.targetted_shape:
1761 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001762 return np.int32(
1763 self.rng.integers(
1764 low=self.args.tensor_shape_range[0],
1765 high=self.args.tensor_shape_range[1],
1766 size=rank,
1767 )
1768 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001769
1770 def setTargetShape(self, shape):
1771 self.targetted_shape = shape
1772
1773 def randInt(self, low=0, high=256):
1774 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
1775
1776 def getRandNumberDType(self, dtype):
1777 if dtype == DType.FLOAT:
1778 return self.rng.random()
1779 elif dtype == DType.BOOL:
1780 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -07001781 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -07001782 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -07001783 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -07001784 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +01001785 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -07001786 elif dtype == DType.INT16:
1787 low, high = (-32768, 32768)
1788 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001789 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -07001790 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001791 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -07001792 # Special size
1793 return np.int64(self.rng.integers(low, high, size=1))[0]
1794 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001795 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07001796
1797 return np.int32(self.rng.integers(low, high, size=1))[0]
1798
1799 def shapeStr(self, shape):
1800
1801 sStr = []
1802 # Convert to strings
1803 for i in shape:
1804 sStr.append(str(i))
1805
Kevin Cheng550ccc52021-03-03 11:21:43 -08001806 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001807
1808 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -07001809 if isinstance(t, list):
1810 assert len(t) >= 2
1811 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001812 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07001813 if t == DType.BOOL:
1814 return "b"
1815 elif t == DType.INT4:
1816 return "i4"
1817 elif t == DType.INT8:
1818 return "i8"
1819 elif t == DType.UINT8:
1820 return "u8"
1821 elif t == DType.INT16:
1822 return "i16"
1823 elif t == DType.INT32:
1824 return "i32"
1825 elif t == DType.INT48:
1826 return "i48"
1827 elif t == DType.FLOAT:
1828 return "float"
1829 else:
1830 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001831
1832 def typeWidth(self, t):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001833 """ Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -08001834 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07001835 return 4
1836 elif t == DType.INT8:
1837 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08001838 elif t == DType.UINT8:
1839 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07001840 elif t == DType.INT16:
1841 return 16
1842 elif t == DType.INT32:
1843 return 32
1844 elif t == DType.INT48:
1845 return 48
1846 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001847 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001848
1849 # Argument generators
1850 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
1851 # Where the string descriptor is used to generate the test name and
1852 # The build_fcn_arg_list is expanded and passed to the operator test
1853 # build function
1854
Kevin Cheng550ccc52021-03-03 11:21:43 -08001855 def build_unary(self, op, a, qinfo=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001856 result_tens = OutputShaper.unaryOp(self.ser, a)
Matthew Haddon848efb42021-09-09 12:30:53 +01001857 # build_placeholder returns an int, ABS/other ops does not
1858 if isinstance(op, int):
1859 self.ser.addOperator(op, [a.name], [result_tens.name], None, qinfo)
1860 else:
1861 self.ser.addOperator(op['op'], [a.name], [result_tens.name], None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07001862 return result_tens
1863
1864 def build_binary_broadcast(self, op, a, b):
1865 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01001866 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001867 return result_tens
1868
1869 def build_binary_nonbroadcast(self, op, a, b):
1870 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01001871 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001872 return result_tens
1873
Kevin Chengaee1fac2020-11-11 13:54:06 -08001874 def build_arithmetic_right_shift(self, op, a, b, round):
1875 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1876
1877 attr = ts.TosaSerializerAttribute()
1878 attr.ArithmeticRightShiftAttribute(round)
1879
Matthew Haddon848efb42021-09-09 12:30:53 +01001880 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -08001881 return result_tens
1882
1883 def build_mul(self, op, a, b, shift):
Eric Kunzee5e26762020-10-13 16:11:07 -07001884 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1885
1886 # Special for multiply:
1887 # Force the result to INT32 for INT types
1888 if a.dtype != DType.FLOAT:
1889 result_tens.setDtype(DType.INT32)
1890
Kevin Chengaee1fac2020-11-11 13:54:06 -08001891 attr = ts.TosaSerializerAttribute()
1892 attr.MulAttribute(shift)
1893
Matthew Haddon848efb42021-09-09 12:30:53 +01001894 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001895 return result_tens
1896
1897 def build_table(self, op, a):
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01001898 # Constant size depending on type, random values
1899 if a.dtype == DType.INT16:
Kevin Chengacb550f2021-06-29 15:32:19 -07001900 table_dtype = DType.INT16
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01001901 table_arr = self.getRandTensor([513], table_dtype)
1902 else:
1903 assert a.dtype == DType.INT8
1904 table_dtype = DType.INT8
1905 table_arr = self.getRandTensor([256], table_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001906
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01001907 table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
1908 result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
Matthew Haddon848efb42021-09-09 12:30:53 +01001909 self.ser.addOperator(op['op'], [a.name, table_tens.name], [result_tens.name], None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001910
1911 return result_tens
1912
1913 def build_select(self, op, cond, a, b):
Eric Kunzee5e26762020-10-13 16:11:07 -07001914 result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01001915 self.ser.addOperator(op['op'], [cond.name, a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001916 return result_tens
1917
1918 def build_comparison(self, op, a, b):
1919 result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01001920 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001921 return result_tens
1922
1923 def build_argmax(self, op, a, axis):
1924 result_tens = OutputShaper.argmaxOp(self.ser, a, axis)
1925
1926 attr = ts.TosaSerializerAttribute()
1927 attr.AxisAttribute(axis)
1928
Matthew Haddon848efb42021-09-09 12:30:53 +01001929 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001930 return result_tens
1931
Matthew Haddonb724efc2021-08-25 16:40:29 +01001932 def build_pool2d(self, op, input, stride, pad, kernel, qinfo=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001933 result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)
1934
1935 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07001936 attr.PoolAttribute(kernel, stride, pad)
Eric Kunzee5e26762020-10-13 16:11:07 -07001937
Matthew Haddon848efb42021-09-09 12:30:53 +01001938 self.ser.addOperator(op['op'], [input.name], [result_tens.name], attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07001939 return result_tens
1940
1941 def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001942 assert len(padding) == 4
1943 result_tens = OutputShaper.conv2dOp(
1944 self.ser, ifm, filter, strides, padding, dilations
1945 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001946
1947 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07001948 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07001949
Kevin Cheng550ccc52021-03-03 11:21:43 -08001950 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01001951 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08001952 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001953 return result_tens
1954
Kevin Cheng1533b852021-09-01 12:51:58 -07001955 def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
1956 assert len(padding) == 6
1957 result_tens = OutputShaper.conv3dOp(
1958 self.ser, ifm, filter, strides, padding, dilations
1959 )
1960
1961 attr = ts.TosaSerializerAttribute()
1962 attr.ConvAttribute(padding, strides, dilations)
1963
1964 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01001965 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng1533b852021-09-01 12:51:58 -07001966 )
1967 return result_tens
1968
Kevin Cheng550ccc52021-03-03 11:21:43 -08001969 def build_transpose_conv2d(
Kevin Cheng989cb052021-04-28 16:29:44 -07001970 self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08001971 ):
1972 assert len(outpad) == 2
Eric Kunzee5e26762020-10-13 16:11:07 -07001973 result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
1974
1975 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07001976 attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07001977
Kevin Cheng550ccc52021-03-03 11:21:43 -08001978 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01001979 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08001980 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001981 return result_tens
1982
Kevin Cheng550ccc52021-03-03 11:21:43 -08001983 def build_depthwise_conv2d(
1984 self, op, ifm, filter, bias, strides, padding, dilations, qinfo
1985 ):
1986 result_tens = OutputShaper.depthwiseConv2dOp(
1987 self.ser, ifm, filter, strides, padding, dilations
1988 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001989
1990 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -07001991 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -07001992
Kevin Cheng550ccc52021-03-03 11:21:43 -08001993 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01001994 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08001995 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001996 return result_tens
1997
1998 def build_fully_connected(self, op, ifm, filter, bias, qinfo):
1999 result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)
2000
Kevin Cheng550ccc52021-03-03 11:21:43 -08002001 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002002 op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002003 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002004 return result_tens
2005
2006 def build_matmul(self, op, a, b, qinfo):
2007 result_tens = OutputShaper.matmulOp(self.ser, a, b)
Matthew Haddon848efb42021-09-09 12:30:53 +01002008 self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07002009 return result_tens
2010
2011 def build_reduce(self, op, a, axis):
2012 result_tens = OutputShaper.reduceOp(self.ser, a, axis)
2013
2014 attr = ts.TosaSerializerAttribute()
2015 attr.AxisAttribute(axis)
2016
Matthew Haddon848efb42021-09-09 12:30:53 +01002017 self.ser.addOperator(op['op'], [a.name], result_tens.name, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002018 return result_tens
2019
2020 def build_clamp(self, op, a):
2021 result_tens = OutputShaper.unaryOp(self.ser, a)
2022
2023 attr = ts.TosaSerializerAttribute()
Jeremy Johnson18e26662021-07-22 16:15:29 +01002024 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07002025
2026 if a.dtype == DType.FLOAT:
2027 attr.ClampAttribute(0, 0, min(v), max(v))
2028 else:
2029 attr.ClampAttribute(min(v), max(v), 0, 0)
2030
Matthew Haddon848efb42021-09-09 12:30:53 +01002031 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002032 return result_tens
2033
2034 def build_leaky_relu(self, op, a):
2035 result_tens = OutputShaper.unaryOp(self.ser, a)
2036 attr = ts.TosaSerializerAttribute()
2037
2038 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
2039
Matthew Haddon848efb42021-09-09 12:30:53 +01002040 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002041 return result_tens
2042
2043 # Needs an additional type/input
2044 def build_prelu(self, op, a):
2045 result_tens = OutputShaper.unaryOp(self.ser, a)
2046
Matthew Haddon848efb42021-09-09 12:30:53 +01002047 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002048 return result_tens
2049
Eric Kunzee5e26762020-10-13 16:11:07 -07002050 def build_sigmoid(self, op, a):
2051 result_tens = OutputShaper.unaryOp(self.ser, a)
Matthew Haddon848efb42021-09-09 12:30:53 +01002052 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002053 return result_tens
2054
2055 def build_tanh(self, op, a):
2056 result_tens = OutputShaper.unaryOp(self.ser, a)
Matthew Haddon848efb42021-09-09 12:30:53 +01002057 self.ser.addOperator(op['op'], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002058 return result_tens
2059
Matthew Haddon818ab902021-07-27 09:12:49 +01002060 def build_concat(self, op, *a):
Kevin Cheng93a16282021-08-31 16:14:03 -07002061 assert type(a[-1]) == int
Matthew Haddon818ab902021-07-27 09:12:49 +01002062
2063 # To store variable length list of input tensors we need to store axis along with it
2064 axis = a[-1]
2065 a = a[:-1]
2066
2067 result_tens = OutputShaper.concatOp(self.ser, axis, *a)
Eric Kunzee5e26762020-10-13 16:11:07 -07002068
2069 attr = ts.TosaSerializerAttribute()
2070 attr.AxisAttribute(axis)
2071
Matthew Haddon818ab902021-07-27 09:12:49 +01002072 input_tensor_names = []
2073 for tensor in a:
2074 input_tensor_names.append(tensor.name)
2075
Matthew Haddon848efb42021-09-09 12:30:53 +01002076 self.ser.addOperator(op['op'], input_tensor_names, [result_tens.name], attr)
2077 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002078
2079 def build_pad(self, op, a, padding, qinfo):
2080 result_tens = OutputShaper.padOp(self.ser, a, padding)
2081
2082 # Need to turn the padding array into a TOSA tensor here.
2083 # This is one of the few tensor operands that does not get
2084 # randomly generated
Kevin Cheng550ccc52021-03-03 11:21:43 -08002085 padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)
Eric Kunzee5e26762020-10-13 16:11:07 -07002086
Kevin Cheng550ccc52021-03-03 11:21:43 -08002087 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002088 op['op'], [a.name, padding_tens.name], [result_tens.name], None, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08002089 )
Matthew Haddone86fd342021-09-07 16:12:21 +01002090 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002091
2092 def build_reshape(self, op, a, newShape):
2093 result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)
2094
2095 attr = ts.TosaSerializerAttribute()
2096 attr.ReshapeAttribute(newShape)
2097
Matthew Haddon848efb42021-09-09 12:30:53 +01002098 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002099 return result_tens
2100
2101 def build_reverse(self, op, a, axis):
2102 result_tens = OutputShaper.unaryOp(self.ser, a)
2103
2104 attr = ts.TosaSerializerAttribute()
2105 attr.AxisAttribute(axis)
2106
Matthew Haddon848efb42021-09-09 12:30:53 +01002107 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002108 return result_tens
2109
2110 def build_transpose(self, op, a, perms):
2111 result_tens = OutputShaper.transposeOp(self.ser, a, perms)
2112
Kevin Cheng550ccc52021-03-03 11:21:43 -08002113 perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))
Eric Kunzee5e26762020-10-13 16:11:07 -07002114
Matthew Haddon848efb42021-09-09 12:30:53 +01002115 self.ser.addOperator(op['op'], [a.name, perms_tens.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002116 return result_tens
2117
2118 def build_slice(self, op, a, begin, size):
2119 result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)
2120
2121 attr = ts.TosaSerializerAttribute()
2122 attr.SliceAttribute(begin, size)
2123
Matthew Haddon848efb42021-09-09 12:30:53 +01002124 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002125 return result_tens
2126
2127 def build_tile(self, op, a, multiples):
2128 result_tens = OutputShaper.tileOp(self.ser, a, multiples)
2129
2130 attr = ts.TosaSerializerAttribute()
2131 attr.TileAttribute(multiples)
2132
Matthew Haddon848efb42021-09-09 12:30:53 +01002133 self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002134 return result_tens
2135
Kevin Cheng77d0f762020-11-24 10:26:32 -08002136 def build_gather(self, op, values):
Eric Kunzee5e26762020-10-13 16:11:07 -07002137
2138 # Create a new indicies tensor
2139 # here with data that doesn't exceed the dimensions of the values tensor
2140
Kevin Cheng550ccc52021-03-03 11:21:43 -08002141 K = values.shape[1] # K
2142 W = self.randInt(
2143 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
2144 ) # W
2145 indicies_arr = np.int32(
2146 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
2147 ) # (N, W)
2148 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002149
Kevin Cheng77d0f762020-11-24 10:26:32 -08002150 result_tens = OutputShaper.gatherOp(self.ser, values, indicies)
Eric Kunzee5e26762020-10-13 16:11:07 -07002151
Matthew Haddon848efb42021-09-09 12:30:53 +01002152 self.ser.addOperator(op['op'], [values.name, indicies.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002153
2154 return result_tens
2155
Kevin Cheng77d0f762020-11-24 10:26:32 -08002156 def build_scatter(self, op, values_in, input):
2157
2158 # Create a new indicies tensor
2159 # here with data that doesn't exceed the dimensions of the values_in tensor
2160
Kevin Cheng550ccc52021-03-03 11:21:43 -08002161 K = values_in.shape[1] # K
2162 W = input.shape[1] # W
2163 indicies_arr = np.int32(
2164 self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
2165 ) # (N, W)
2166 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Kevin Cheng77d0f762020-11-24 10:26:32 -08002167
2168 result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)
2169
Kevin Cheng550ccc52021-03-03 11:21:43 -08002170 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002171 op['op'], [values_in.name, indicies.name, input.name], [result_tens.name]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002172 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08002173
2174 return result_tens
2175
Matthew Haddon848efb42021-09-09 12:30:53 +01002176
Kevin Cheng550ccc52021-03-03 11:21:43 -08002177 def build_resize(
2178 self,
2179 op,
2180 input,
2181 mode,
2182 stride,
2183 offset,
2184 shift,
2185 stride_fp,
2186 offset_fp,
2187 output_dims,
2188 input_dtype,
2189 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01002190 validator_fcns,
2191 error_name = None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002192 ):
2193 result_tens = OutputShaper.resizeOp(
2194 self.ser,
2195 input,
2196 mode,
2197 stride,
2198 offset,
2199 shift,
2200 stride_fp,
2201 offset_fp,
2202 output_dims,
2203 input_dtype,
2204 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01002205 error_name
Kevin Cheng550ccc52021-03-03 11:21:43 -08002206 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002207
Matthew Haddon848efb42021-09-09 12:30:53 +01002208 # Invalidate Input/Output list for error if checks.
2209 input_list = [input.name]
2210 output_list = [result_tens.name]
2211 pCount, cCount = op["operands"]
2212 num_operands = pCount + cCount
2213 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
Matthew Haddone86fd342021-09-07 16:12:21 +01002214
Matthew Haddon848efb42021-09-09 12:30:53 +01002215 TosaErrorValidator.evValidateErrorIfs(
2216 self.ser,
2217 validator_fcns,
2218 error_name,
2219 op=op,
2220 mode=mode,
2221 shift=shift,
2222 input_dtype=input_dtype,
2223 output_dtype=output_dtype,
2224 input_shape=input,
2225 output_shape=output_dims,
2226 offset=offset,
2227 offset_fp=offset_fp,
2228 stride=stride,
2229 stride_fp=stride_fp,
2230 input_list=input_list,
2231 output_list=output_list,
2232 num_operands=num_operands,
2233 )
Matthew Haddone86fd342021-09-07 16:12:21 +01002234
Eric Kunzee5e26762020-10-13 16:11:07 -07002235 attr = ts.TosaSerializerAttribute()
Kevin Cheng77d0f762020-11-24 10:26:32 -08002236
Kevin Cheng550ccc52021-03-03 11:21:43 -08002237 attr.ResizeAttribute(
2238 output_dims, stride, offset, shift, stride_fp, offset_fp, mode
2239 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002240
Matthew Haddon848efb42021-09-09 12:30:53 +01002241 self.ser.addOperator(op['op'], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002242 return result_tens
2243
2244 def build_identityn(self, op, val, val2):
2245
Kevin Cheng550ccc52021-03-03 11:21:43 -08002246 result_tens = OutputShaper.unaryOp(self.ser, val)
Eric Kunzee5e26762020-10-13 16:11:07 -07002247 result_tens2 = OutputShaper.unaryOp(self.ser, val2)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002248 self.ser.addOperator(
2249 op, [val.name, val2.name], [result_tens.name, result_tens2.name]
2250 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002251 return result_tens
2252
Kevin Cheng17e92022021-10-01 14:33:33 -07002253 def build_const(self, op, val):
2254 self.ser.addOutputTensor(val)
2255 return val
Eric Kunzee5e26762020-10-13 16:11:07 -07002256
2257 # Type Conversion
2258 def build_cast(self, op, val, out_dtype):
2259 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
Matthew Haddon848efb42021-09-09 12:30:53 +01002260 self.ser.addOperator(op['op'], [val.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002261 return result_tens
2262
2263 def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
2264 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
2265
2266 if per_channel:
2267 nc = val.shape[-1]
2268 else:
2269 nc = 1
2270
2271 in_type_width = self.typeWidth(val.dtype)
2272 out_type_width = self.typeWidth(out_dtype)
2273
Kevin Cheng3a478572021-01-22 17:21:02 -08002274 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002275 input_zp = self.randInt(-128, 128)
2276 in_type_width = in_type_width + 1
Kevin Chengacb550f2021-06-29 15:32:19 -07002277 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002278 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07002279 in_type_width = in_type_width + 1
2280 else:
2281 input_zp = 0
2282
Kevin Cheng3a478572021-01-22 17:21:02 -08002283 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002284 output_zp = self.randInt(-128, 128)
2285 out_type_width = out_type_width + 1
2286 elif out_dtype == DType.UINT8:
2287 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07002288 out_type_width = out_type_width + 1
2289 else:
2290 output_zp = 0
2291
2292 # Calculate scale based on:
2293 # scale = a *(2^output_width)/(2^input_width))
2294
2295 a = np.float32(self.rng.random(size=[nc]))
2296 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
2297
2298 if scale32:
2299 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01002300 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07002301 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
2302 else:
2303 # Cap the scaling at 2^15 - 1 for scale16
2304 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
2305
Kevin Cheng550ccc52021-03-03 11:21:43 -08002306 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002307
2308 multiplier_arr = np.int32(np.zeros(shape=[nc]))
2309 shift_arr = np.int32(np.zeros(shape=[nc]))
2310
2311 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002312 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
2313 scale_arr[i], scale32
2314 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002315
Kevin Cheng550ccc52021-03-03 11:21:43 -08002316 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07002317
2318 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002319 attr.RescaleAttribute(
2320 input_zp,
2321 output_zp,
2322 multiplier_arr,
2323 shift_arr,
2324 scale32,
2325 double_round,
2326 per_channel,
2327 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002328
Matthew Haddon848efb42021-09-09 12:30:53 +01002329 self.ser.addOperator(op['op'], [val.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002330 return result_tens
2331
2332 def build_cond_if_const(self, op, then_tens, else_tens, cond):
2333 # For cond_if with constants, we're supplied with then/else tensors that we ignore
2334 # (except for the generated shap) and the condition. Build Then/Else blocks
2335 # and fill them with const nodes for the body.
2336
2337 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08002338 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07002339
2340 # Make then/else tensors
2341 out_shape = then_tens.shape
Jeremy Johnson18e26662021-07-22 16:15:29 +01002342 then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
2343 else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002344
2345 # And the result tensor based on any of the outputs
Kevin Cheng550ccc52021-03-03 11:21:43 -08002346 result_tens = self.ser.addOutput(out_shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07002347
2348 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08002349 then_block = "THEN_BLOCK"
2350 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002351 attr = ts.TosaSerializerAttribute()
2352 attr.CondIfAttribute(then_block, else_block)
2353
2354 # Finally, build the op and the two blocks
Matthew Haddon848efb42021-09-09 12:30:53 +01002355 self.ser.addOperator(op['op'], [cond_tens.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002356
2357 self.ser.startBasicBlock(then_block)
2358 # Build the actual then/else tensors inside their blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08002359 then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002360 self.ser.addOutputTensor(then_tens)
2361
2362 self.ser.startBasicBlock(else_block)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002363 else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002364 self.ser.addOutputTensor(else_tens)
2365
2366 return result_tens
2367
2368 def build_cond_if_binary(self, op, a, b, cond):
2369 # For cond_if with a binary op in the then/else blocks, take a and b and
2370 # alternately add or subtract them based on the condition
2371
2372 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08002373 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07002374
Kevin Cheng550ccc52021-03-03 11:21:43 -08002375 result_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002376
2377 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08002378 then_block = "THEN_BLOCK"
2379 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002380 attr = ts.TosaSerializerAttribute()
2381 attr.CondIfAttribute(then_block, else_block)
2382
2383 # Finally, build the op and the two blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08002384 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002385 op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr
Kevin Cheng550ccc52021-03-03 11:21:43 -08002386 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002387
2388 self.ser.startBasicBlock(then_block)
2389 self.ser.addInputTensor(a)
2390 self.ser.addInputTensor(b)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002391 then_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002392 self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])
2393
2394 self.ser.startBasicBlock(else_block)
2395 self.ser.addInputTensor(a)
2396 self.ser.addInputTensor(b)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002397 else_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002398 self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])
2399
2400 return result_tens
2401
2402 def build_while_loop(self, op, a, iter_val):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002403 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07002404
Kevin Cheng550ccc52021-03-03 11:21:43 -08002405 cond_block = "COND_BLOCK"
2406 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002407
2408 attr = ts.TosaSerializerAttribute()
2409 attr.WhileLoopAttribute(cond_block, body_block)
2410
2411 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08002412 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002413 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08002414 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07002415
2416 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002417 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2418 a_out = self.ser.addIntermediate(a.shape, a.dtype)
2419 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002420
2421 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08002422 self.ser.addOperator(
Matthew Haddon848efb42021-09-09 12:30:53 +01002423 op['op'],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002424 [iter.name, a.name, acc.name],
2425 [iter_out.name, a_out.name, acc_out.name],
2426 attr,
2427 )
Kevin Chengb227ae52021-09-02 13:43:17 -07002428 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07002429
2430 # COND block (input: iter, output: cond_tens )
2431 self.ser.startBasicBlock(cond_block)
2432 self.ser.addInputTensor(iter)
2433 self.ser.addInputTensor(a)
2434 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002435 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
2436 cond_tens = self.ser.addOutput([], DType.BOOL)
2437 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002438
2439 # BODY block (input: a, acc, iter, output: a, acc, iter)
2440 # Note that local intermediate tensors need to be declared here for the outputs
2441 self.ser.startBasicBlock(body_block)
2442 self.ser.addInputTensor(iter)
2443 self.ser.addInputTensor(a)
2444 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002445 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
2446 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2447 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002448 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2449 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2450 self.ser.addOutputTensor(iter_body_out)
2451 self.ser.addOutputTensor(a)
2452 self.ser.addOutputTensor(acc_body_out)
2453
2454 return acc_out
2455
Kevin Cheng550ccc52021-03-03 11:21:43 -08002456 def genOpTestList(
Matthew Haddon74567092021-07-16 15:38:20 +01002457 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
Kevin Cheng550ccc52021-03-03 11:21:43 -08002458 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002459
2460 try:
2461 op = self.TOSA_OP_LIST[opName]
2462 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002463 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002464
2465 # Initialize a new random number generator
2466 self.rng = np.random.default_rng(self.random_seed)
2467
Kevin Cheng550ccc52021-03-03 11:21:43 -08002468 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002469
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002470 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
2471 default_test_rank_range = range(1, 5)
Matthew Haddone86fd342021-09-07 16:12:21 +01002472 if not shapeFilter:
2473 shapeFilter = [None]
2474
2475 # Generate the lists of arguments
2476 rmin, rmax = op["rank"]
2477 if rankFilter is not None:
2478 cleanRankFilter = []
2479 # Ensure rankFilter values are allowed by operator
2480 for rank in rankFilter:
2481 if rank >= rmin and rank <= rmax:
2482 cleanRankFilter.append(rank)
2483 rankFilter = cleanRankFilter
2484 elif rankFilter is None and shapeFilter[0] is None:
2485 cleanRankFilter = []
2486 # Ensure default behaviour is bounded by default range or by operator, whichever is smaller.
2487 rankRange = range(rmin, rmax + 1)
2488 for rank in rankRange:
2489 if rank >= min(default_test_rank_range) and rank <= max(default_test_rank_range):
2490 cleanRankFilter.append(rank)
2491 rankFilter = cleanRankFilter
2492 else:
2493 rankFilter = range(rmin, rmax + 1)
2494
2495 dtypes = op["types"]
2496 if dtypeFilter is not None:
2497 cleanDtypeFilter = []
2498 # Ensure filtered dtypes are allowed by operator
2499 for dtype in dtypeFilter:
2500 if dtype in dtypes:
2501 cleanDtypeFilter.append(dtype)
2502 dtypeFilter = cleanDtypeFilter
2503 else:
2504 dtypeFilter = dtypes
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002505
Eric Kunzee5e26762020-10-13 16:11:07 -07002506 # Test list consists of a tuple of:
2507 # (opName, testNameStr, dtype, shapeList, argumentsList)
2508 testList = []
2509
Matthew Haddon74567092021-07-16 15:38:20 +01002510 # Positive test loop
2511 if testType in ['positive', 'both']:
Matthew Haddone86fd342021-09-07 16:12:21 +01002512 for r in rankFilter:
Kevin Cheng1533b852021-09-01 12:51:58 -07002513 if opName.startswith("conv3d"):
2514 assert r == 5, "conv3d test must have input rank == 5"
Matthew Haddone86fd342021-09-07 16:12:21 +01002515 for t in dtypeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01002516 # Create the placeholder and const tensors
2517 for shape in shapeFilter:
2518 # A None shape chooses a random shape of a given rank
Eric Kunzee5e26762020-10-13 16:11:07 -07002519
Matthew Haddon74567092021-07-16 15:38:20 +01002520 # Filter out by rank
2521 if shape is not None and len(shape) != r:
2522 continue
Matthew Haddon74567092021-07-16 15:38:20 +01002523 self.setTargetShape(shape)
2524 shapeList = tgen_fcn(self, op, r)
Eric Kunzee5e26762020-10-13 16:11:07 -07002525
Matthew Haddon74567092021-07-16 15:38:20 +01002526 shapeStr = self.shapeStr(shapeList[0])
2527 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07002528
Matthew Haddon74567092021-07-16 15:38:20 +01002529 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
2530 argList = []
2531 if agen_fcn:
2532 argList = agen_fcn(self, opName, shapeList, t)
Eric Kunzee5e26762020-10-13 16:11:07 -07002533 else:
Matthew Haddon74567092021-07-16 15:38:20 +01002534 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07002535
Matthew Haddon74567092021-07-16 15:38:20 +01002536 for argStr, args in argList:
2537 if argStr:
2538 testStr = "{}_{}_{}_{}".format(
2539 opName, shapeStr, typeStr, argStr
2540 )
2541 else:
2542 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
2543
Matthew Haddone86fd342021-09-07 16:12:21 +01002544 testList.append((opName, testStr, t, None, shapeList, args))
Matthew Haddon74567092021-07-16 15:38:20 +01002545
Matthew Haddonb724efc2021-08-25 16:40:29 +01002546 # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
2547 if "invalid_test_validators" in op:
2548 invalid_test_validators = op["invalid_test_validators"]
2549 clean_testList = []
2550 for test in testList:
2551 for validator_fcn in invalid_test_validators:
2552 remove_test = False
Matthew Haddone86fd342021-09-07 16:12:21 +01002553 if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5]):
Matthew Haddonb724efc2021-08-25 16:40:29 +01002554 remove_test = True
2555 if not remove_test:
2556 clean_testList.append(test)
2557 testList = clean_testList
2558
Matthew Haddone86fd342021-09-07 16:12:21 +01002559 # Store the original filters so they can be reused if required
2560 base_rankFilter = rankFilter
2561 base_dtypeFilter = dtypeFilter
2562 base_shapeFilter = shapeFilter
Matthew Haddon74567092021-07-16 15:38:20 +01002563 # Reset RNG so both positive and negative tests are reproducible
2564 self.resetRNG()
Matthew Haddone86fd342021-09-07 16:12:21 +01002565
Matthew Haddon74567092021-07-16 15:38:20 +01002566 # Negative test loop
Matthew Haddone86fd342021-09-07 16:12:21 +01002567 if testType in ['negative', 'both'] and "error_if_validators" in op:
2568 error_if_validators = op["error_if_validators"]
2569 for validator in error_if_validators:
Matthew Haddon848efb42021-09-09 12:30:53 +01002570 validator_info = validator(check=False, op=op)
Matthew Haddone86fd342021-09-07 16:12:21 +01002571 error_name = validator_info['error_name']
2572 error_arguments = validator_info['param_reqs']
2573
2574 #Set parameters as required
2575 if error_arguments['rank'] != None:
Matthew Haddon848efb42021-09-09 12:30:53 +01002576 rankFilter = error_arguments['rank']
Matthew Haddone86fd342021-09-07 16:12:21 +01002577 else:
2578 rankFilter = base_rankFilter
2579 if error_arguments['dtype'] != None:
2580 dtypeFilter = error_arguments['dtype']
2581 else:
2582 dtypeFilter = base_dtypeFilter
2583 if error_arguments['shape'] != None:
2584 shapes = error_arguments['shape']
2585 else:
2586 shapes = base_shapeFilter[:2] # Reduce number of shapes to keep test numbers small
2587
Matthew Haddon848efb42021-09-09 12:30:53 +01002588 for r in rankFilter:
Matthew Haddone86fd342021-09-07 16:12:21 +01002589 for t in dtypeFilter:
2590 # Create the placeholder and const tensors
2591 for shape in shapes:
2592 # A None shape chooses a random shape of a given rank
2593 # Filter out by rank
2594 if shape is not None and len(shape) != r:
2595 continue
2596 self.setTargetShape(shape)
2597 shapeList = tgen_fcn(self, op, r, error_name)
2598 shapeStr = self.shapeStr(shapeList[0])
2599 typeStr = self.typeStr(t)
2600 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
2601 argList = []
2602 if agen_fcn:
2603 argList = agen_fcn(self, opName, shapeList, t, error_name)
2604 else:
2605 argList = [("", [])]
2606 for argStr, args in argList:
2607 if argStr:
2608 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
2609 opName, error_name, shapeStr, typeStr, argStr
2610 )
2611 else:
2612 testStr = "{}_ERRORIF_{}_{}_{}".format(opName, error_name, shapeStr, typeStr)
2613 testList.append((opName, testStr, t, error_name, shapeList, args))
Eric Kunzee5e26762020-10-13 16:11:07 -07002614
2615 return testList
2616
Matthew Haddone86fd342021-09-07 16:12:21 +01002617
2618 def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
Eric Kunzee5e26762020-10-13 16:11:07 -07002619 try:
2620 op = self.TOSA_OP_LIST[opName]
2621 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002622 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002623
2624 # Create a serializer
2625 self.createSerializer(opName, testStr)
2626
Kevin Cheng550ccc52021-03-03 11:21:43 -08002627 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002628 if "error_if_validators" in op:
2629 error_if_validators = op["error_if_validators"]
2630 else:
2631 error_if_validators = None
2632
Kevin Cheng550ccc52021-03-03 11:21:43 -08002633 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07002634 num_operands = pCount + cCount
2635
2636 if isinstance(dtype_or_dtypeList, list):
2637 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07002638 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002639 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07002640 else:
2641 dtypeList = [dtype_or_dtypeList] * (num_operands)
2642
Kevin Cheng93a16282021-08-31 16:14:03 -07002643 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002644 assert (
2645 len(shapeList) == num_operands
2646 ), "shapeList length {} must match number of operands {}".format(
2647 len(shapeList), num_operands
2648 )
2649 assert (
2650 len(dtypeList) == num_operands
2651 ), "dtypeList length {} must match number of operands {}".format(
2652 len(dtypeList), num_operands
2653 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002654
2655 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002656 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002657 except KeyError:
2658 qgen = None
2659
2660 # Build the random tensor operands and the test
2661 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08002662
Jeremy Johnsonef509a42021-09-07 13:59:47 +01002663 if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32:
2664 # Make sure the operation does not cause value saturation - where
2665 # the number wraps due to limited number of bits to store the answer
2666 assert (
2667 pCount == 2 and cCount == 0
2668 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
2669
2670 placeholders = []
2671 add = (op["op"] == Op.ADD)
2672 a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
2673 b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
2674 if add:
2675 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
2676 else:
2677 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
2678
2679 # Work out the saturation limits
2680 max_i32 = (1 << 31)-1
2681 min_i32 = -(1 << 31)
2682 max_arr = np.full(shapeList[1], max_i32)
2683 min_arr = np.full(shapeList[1], min_i32)
2684
2685 # Find how much values exceed the maximum/minimums
2686 sat_max_arr = np.maximum(res_arr - max_arr, 0)
2687 sat_min_arr = np.minimum(res_arr - min_arr, 0)
2688
2689 if not add:
2690 # Swap saturation values and negate values as we need to perform opposite operations
2691 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
2692
2693 # Create new array of unsaturated values by clipping values as needed
2694 b_unsat_arr = b_arr
2695 if (sat_max_arr != 0).any():
2696 # Clip values that cause saturation
2697 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
2698 # Reduce axes in unsaturated tensor to match original tensor
2699 for axis, dim in enumerate(b_arr.shape):
2700 if dim != b_unsat_arr.shape[axis]:
2701 assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
2702 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
2703
2704 if (sat_min_arr != 0).any():
2705 # Clip values that cause saturation
2706 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
2707 # Reduce axes in unsaturated tensor to match original tensor
2708 for axis, dim in enumerate(b_arr.shape):
2709 if dim != b_unsat_arr.shape[axis]:
2710 assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
2711 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
2712
2713 placeholders.append(
2714 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
2715 )
2716 placeholders.append(
2717 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
2718 )
2719
2720 tens.extend(placeholders)
2721 elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
2722 # Force value of operand[1] to be within [0, num_bits]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002723 assert (
2724 pCount == 2 and cCount == 0
2725 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
Kevin Chengaee1fac2020-11-11 13:54:06 -08002726
2727 placeholders = []
2728 for idx, shape in enumerate(shapeList[:]):
2729 if idx == 1:
Kevin Cheng989cb052021-04-28 16:29:44 -07002730 if dtypeList[idx] == DType.INT8:
Kevin Chengaee1fac2020-11-11 13:54:06 -08002731 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07002732 elif dtypeList[idx] == DType.INT16:
Kevin Chengaee1fac2020-11-11 13:54:06 -08002733 arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07002734 elif dtypeList[idx] == DType.INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -08002735 arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
2736 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002737 raise Exception("OpArithmeticRightShift: invalid input dtype")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002738 else:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002739 arr = self.getRandTensor(shape, dtypeList[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -07002740 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
Kevin Chengaee1fac2020-11-11 13:54:06 -08002741
2742 tens.extend(placeholders)
Matthew Haddona44ac5e2021-07-27 16:31:16 +01002743 elif op["op"] == Op.SELECT:
2744 # Set datatype of condition tensor to boolean
2745 dtypeList[0] = DType.BOOL
2746 tens.extend(
2747 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
2748 )
2749 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Matthew Haddon459443c2021-08-23 16:43:13 +01002750 elif op["op"] == Op.INTDIV:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002751 assert (
2752 pCount == 2 and cCount == 0
Matthew Haddon459443c2021-08-23 16:43:13 +01002753 ), "Op.INTDIV must have 2 placeholders, 0 consts"
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002754
2755 placeholders = []
2756
Matthew Haddon459443c2021-08-23 16:43:13 +01002757 # Two invalid cases for Op.INTDIV:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002758 # 1. divisor == 0
Kevin Cheng47315e12021-05-13 17:41:28 -07002759 # 2. dividend == -(1<<31) and divisor == -1
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002760 while True:
2761 dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
2762 divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])
2763
2764 if (divisor_arr == 0).any():
2765 continue
2766
Kevin Cheng47315e12021-05-13 17:41:28 -07002767 if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002768 continue
2769
2770 break
2771
2772 placeholders.append(
2773 self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
2774 )
2775 placeholders.append(
2776 self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
2777 )
2778
2779 tens.extend(placeholders)
2780 elif op["op"] == Op.MUL:
2781 assert (
2782 pCount == 2 and cCount == 0
2783 ), "Op.MUL must have 2 placeholders, 0 consts"
2784
2785 if dtypeList[0] == DType.FLOAT:
2786 tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
2787 else:
2788 placeholders = []
2789
2790 # Make sure multiply result in int32 range
2791 shift = testArgs[0]
2792 if dtypeList[0] == DType.INT8:
2793 num_bits = 8
2794 elif dtypeList[0] == DType.INT16:
2795 num_bits = 16
2796 elif dtypeList[0] == DType.INT32:
2797 num_bits = 32
2798 else:
2799 raise Exception("OpMul: invalid input dtype")
2800
2801 for idx, shape in enumerate(shapeList[:]):
2802 low = -(2 ** (num_bits - 1))
2803 high = (2 ** (num_bits - 1)) - 1
2804
2805 a_arr = np.int32(
2806 self.rng.integers(low=low, high=high, size=shapeList[0])
2807 )
2808 b_arr = np.int32(
2809 self.rng.integers(low=low, high=high, size=shapeList[1])
2810 )
2811
2812 i = 0
2813 while True:
2814
2815 a_arr_64 = a_arr.astype(np.int64)
2816 b_arr_64 = b_arr.astype(np.int64)
2817
2818 if shift > 0:
2819 rounding = 1 << (shift - 1)
2820 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
2821 else:
2822 result_arr = a_arr_64 * b_arr_64
2823
2824 if (result_arr > -(2 ** 31)).all() and (
2825 result_arr <= ((2 ** 31) - 1)
2826 ).all():
2827 break
2828
2829 i = i + 1
2830 a_arr = a_arr // 2
2831 b_arr = b_arr // 2
2832
2833 placeholders.append(
2834 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
2835 )
2836 placeholders.append(
2837 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
2838 )
2839
2840 tens.extend(placeholders)
Matthew Haddon818ab902021-07-27 09:12:49 +01002841 elif op["op"] == Op.CONCAT:
2842 count = len(shapeList) - self.args.num_const_inputs_concat
2843 if count < 1:
2844 count = 1
2845 if self.args.num_const_inputs_concat == 0:
2846 count = len(shapeList)
2847
2848 shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
2849 tens.extend(
2850 self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
2851 )
2852 tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08002853 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002854 tens.extend(
2855 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
2856 )
2857 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Eric Kunzee5e26762020-10-13 16:11:07 -07002858
2859 if qgen is not None:
Les Bell30e46802021-07-23 09:43:31 +01002860 qinfo = qgen(self, op, dtype_or_dtypeList)
Eric Kunzee5e26762020-10-13 16:11:07 -07002861 else:
2862 qinfo = None
2863
2864 try:
Matthew Haddone86fd342021-09-07 16:12:21 +01002865 if error_if_validators is None:
2866 if qinfo is not None:
Matthew Haddon848efb42021-09-09 12:30:53 +01002867 resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
Matthew Haddone86fd342021-09-07 16:12:21 +01002868 else:
Matthew Haddon848efb42021-09-09 12:30:53 +01002869 resultName = build_fcn(self, op, *tens, *testArgs)
Eric Kunzee5e26762020-10-13 16:11:07 -07002870 else:
Matthew Haddone86fd342021-09-07 16:12:21 +01002871 if qinfo is not None:
Matthew Haddon848efb42021-09-09 12:30:53 +01002872 resultName = build_fcn(self, op, *tens, *testArgs, qinfo, error_if_validators, error_name)
Matthew Haddone86fd342021-09-07 16:12:21 +01002873 else:
Matthew Haddon848efb42021-09-09 12:30:53 +01002874 resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002875 except TypeError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002876 print(
2877 "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
2878 build_fcn, tens, testArgs
2879 )
2880 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002881 raise e
2882
Matthew Haddone86fd342021-09-07 16:12:21 +01002883 if resultName is None:
2884 print("Invalid ERROR_IF tests created")
2885
Eric Kunzee5e26762020-10-13 16:11:07 -07002886 # Save the serialized test
Kevin Cheng550ccc52021-03-03 11:21:43 -08002887 self.serialize("test")
Eric Kunzee5e26762020-10-13 16:11:07 -07002888
2889 def createDynamicOpLists(self):
2890
2891 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng1533b852021-09-01 12:51:58 -07002892 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002893
Kevin Cheng1533b852021-09-01 12:51:58 -07002894 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002895 testName = "conv2d_{}x{}".format(k[0], k[1])
2896 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
2897 self.TOSA_OP_LIST[testName]["filter"] = k
2898 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002899
Kevin Cheng550ccc52021-03-03 11:21:43 -08002900 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
2901 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2902 "depthwise_conv2d_TEMPLATE"
2903 ].copy()
2904 self.TOSA_OP_LIST[testName]["filter"] = k
2905 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002906
Kevin Cheng550ccc52021-03-03 11:21:43 -08002907 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
2908 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2909 "transpose_conv2d_TEMPLATE"
2910 ].copy()
2911 self.TOSA_OP_LIST[testName]["filter"] = k
2912 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002913
Kevin Cheng1533b852021-09-01 12:51:58 -07002914 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
2915 for k in KERNELS_3D:
2916 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
2917 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
2918 self.TOSA_OP_LIST[testName]["filter"] = k
2919 self.TOSA_OP_LIST[testName]["template"] = False
2920
Eric Kunzee5e26762020-10-13 16:11:07 -07002921 # Delete any templates after having created any dynamic ops
2922 # This is a two-pass operation because it's bad practice to delete
2923 # keys from dictionaries while iterating
2924 keyList = []
2925 for k in self.TOSA_OP_LIST:
2926 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002927 if self.TOSA_OP_LIST[k]["template"] == True:
Eric Kunzee5e26762020-10-13 16:11:07 -07002928 keyList.append(k)
2929 continue
2930 except KeyError:
2931 pass
2932
2933 for k in keyList:
2934 del self.TOSA_OP_LIST[k]
2935
2936 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002937 """Fill in default fields for ops if they aren't already specified.
2938 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07002939 for op in self.TOSA_OP_LIST:
2940
2941 # Required fields
2942 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002943 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002944 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002945 raise Exception(
2946 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
2947 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002948
2949 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002950 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002951 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002952 raise Exception(
2953 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
2954 op
2955 )
2956 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002957
2958 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002959 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002960 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002961 raise Exception(
2962 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
2963 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002964
2965 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002966 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002967 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002968 raise Exception(
2969 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
2970 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002971
2972 # Put in default rank range, if missing
2973 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002974 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002975 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002976 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002977
2978 # Tensor operator list
2979 # 'op': op name
2980 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08002981 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
2982 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07002983 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
2984 # 'types': array of datatypes to be tested
Kevin Cheng550ccc52021-03-03 11:21:43 -08002985 TYPE_FP = [DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07002986
Kevin Cheng550ccc52021-03-03 11:21:43 -08002987 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
2988 TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07002989
Kevin Cheng550ccc52021-03-03 11:21:43 -08002990 TYPE_BOOL = [DType.BOOL]
2991 TYPE_FI32 = [DType.FLOAT, DType.INT32]
2992 TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
2993 TYPE_FI16 = [DType.FLOAT, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002994
Kevin Cheng550ccc52021-03-03 11:21:43 -08002995 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07002996
Kevin Cheng1533b852021-09-01 12:51:58 -07002997 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07002998 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002999 [DType.INT8, DType.INT8, DType.INT32],
3000 [DType.INT16, DType.INT8, DType.INT48],
3001 DType.FLOAT,
3002 ]
3003
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01003004 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003005
3006 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003007 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003008 "argmax": {
3009 "op": Op.ARGMAX,
3010 "operands": (1, 0),
3011 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3012 "types": TYPE_NARROW_INT_FP,
3013 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003014 "avg_pool2d": {
3015 "op": Op.AVG_POOL2D,
3016 "operands": (1, 0),
3017 "rank": (4, 4),
3018 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
3019 "qgen": TosaQuantGen.qgUnary,
3020 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003021 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,)
Jared Smolens573ecd42021-03-04 15:24:10 -08003022 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003023 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003024 "conv2d_TEMPLATE": {
3025 "op": Op.CONV2D,
3026 "operands": (1, 2),
3027 "rank": (4, 4),
Les Bell7aa69f42021-09-20 10:44:07 +01003028 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003029 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003030 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003031 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003032 "template": True,
3033 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003034 # Templated operator. Filled in by createDynamicOpLists
3035 "conv3d_TEMPLATE": {
3036 "op": Op.CONV3D,
3037 "operands": (1, 2),
3038 "rank": (5, 5),
Les Bell7aa69f42021-09-20 10:44:07 +01003039 "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
Kevin Cheng1533b852021-09-01 12:51:58 -07003040 "qgen": TosaQuantGen.qgConv,
3041 "types": TYPE_CONV,
3042 "template": True,
3043 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003044 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003045 "depthwise_conv2d_TEMPLATE": {
3046 "op": Op.DEPTHWISE_CONV2D,
3047 "operands": (1, 2),
3048 "filter": [1, 1],
3049 "rank": (4, 4),
3050 "build_fcn": (
3051 build_depthwise_conv2d,
3052 TosaTensorGen.tgDepthwiseConv2D,
Les Bell7aa69f42021-09-20 10:44:07 +01003053 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003054 ),
3055 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003056 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003057 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003058 "template": True,
3059 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003060 "fully_connected": {
3061 "op": Op.FULLY_CONNECTED,
3062 "operands": (1, 2),
3063 "rank": (2, 2),
3064 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
3065 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003066 "types": TYPE_CONV,
Jared Smolens573ecd42021-03-04 15:24:10 -08003067 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003068 "matmul": {
3069 "op": Op.MATMUL,
3070 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003071 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08003072 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
3073 "qgen": TosaQuantGen.qgMatmul,
3074 "types": TYPE_NARROW_INT_FP,
3075 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003076 "max_pool2d": {
3077 "op": Op.MAX_POOL2D,
3078 "operands": (1, 0),
3079 "rank": (4, 4),
3080 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
3081 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003082 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,)
Jared Smolens573ecd42021-03-04 15:24:10 -08003083 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003084 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003085 "transpose_conv2d_TEMPLATE": {
3086 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003087 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003088 "rank": (4, 4),
3089 "build_fcn": (
3090 build_transpose_conv2d,
3091 TosaTensorGen.tgTransposeConv2D,
3092 TosaArgGen.agTransposeConv2D,
3093 ),
3094 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003095 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01003096 "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003097 "template": True,
3098 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003099 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003100 "clamp": {
3101 "op": Op.CLAMP,
3102 "operands": (1, 0),
3103 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
3104 "types": TYPE_NARROW_INT_FP,
3105 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003106 "sigmoid": {
3107 "op": Op.SIGMOID,
3108 "operands": (1, 0),
3109 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
3110 "types": TYPE_FP,
3111 },
3112 "tanh": {
3113 "op": Op.TANH,
3114 "operands": (1, 0),
3115 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
3116 "types": TYPE_FP,
3117 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003118 # Elementwise Binary Operators
3119 "add": {
3120 "op": Op.ADD,
3121 "operands": (2, 0),
3122 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3123 "types": TYPE_FI32,
3124 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003125 "arithmetic_right_shift": {
3126 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3127 "operands": (2, 0),
3128 "build_fcn": (
3129 build_arithmetic_right_shift,
3130 TosaTensorGen.tgBroadcastFuzz,
3131 TosaArgGen.agArithmeticRightShift,
3132 ),
3133 "types": TYPE_INT,
3134 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003135 "bitwise_and": {
3136 "op": Op.BITWISE_AND,
3137 "operands": (2, 0),
3138 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3139 "types": TYPE_INT,
3140 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003141 "bitwise_or": {
3142 "op": Op.BITWISE_OR,
3143 "operands": (2, 0),
3144 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3145 "types": TYPE_INT,
3146 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003147 "bitwise_xor": {
3148 "op": Op.BITWISE_XOR,
3149 "operands": (2, 0),
3150 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3151 "types": TYPE_INT,
3152 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003153 "intdiv": {
3154 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003155 "operands": (2, 0),
3156 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3157 "types": [DType.INT32],
3158 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003159 "logical_and": {
3160 "op": Op.LOGICAL_AND,
3161 "operands": (2, 0),
3162 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3163 "types": TYPE_BOOL,
3164 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003165 "logical_left_shift": {
3166 "op": Op.LOGICAL_LEFT_SHIFT,
3167 "operands": (2, 0),
3168 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3169 "types": TYPE_INT,
3170 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003171 "logical_right_shift": {
3172 "op": Op.LOGICAL_RIGHT_SHIFT,
3173 "operands": (2, 0),
3174 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3175 "types": TYPE_INT,
3176 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003177 "logical_or": {
3178 "op": Op.LOGICAL_OR,
3179 "operands": (2, 0),
3180 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3181 "types": TYPE_BOOL,
3182 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003183 "logical_xor": {
3184 "op": Op.LOGICAL_XOR,
3185 "operands": (2, 0),
3186 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3187 "types": TYPE_BOOL,
3188 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003189 "maximum": {
3190 "op": Op.MAXIMUM,
3191 "operands": (2, 0),
3192 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3193 "types": TYPE_FI32,
3194 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003195 "minimum": {
3196 "op": Op.MINIMUM,
3197 "operands": (2, 0),
3198 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3199 "types": TYPE_FI32,
3200 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003201 "mul": {
3202 "op": Op.MUL,
3203 "operands": (2, 0),
3204 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
3205 "types": TYPE_INT_FP,
3206 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003207 "pow": {
3208 "op": Op.POW,
3209 "operands": (2, 0),
3210 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
3211 "types": TYPE_FP,
3212 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003213 "sub": {
3214 "op": Op.SUB,
3215 "operands": (2, 0),
3216 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
3217 "types": TYPE_FI32,
3218 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003219 "table": {
3220 "op": Op.TABLE,
3221 # Use the automatic generation functions to create the input array
3222 # but create the table tensor in the build function, as it may be
3223 # a different type from the input
3224 "operands": (1, 0),
3225 "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003226 "types": [DType.INT8, DType.INT16],
Jared Smolens573ecd42021-03-04 15:24:10 -08003227 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003228 # Elementwise Unary operators
3229 "abs": {
3230 "op": Op.ABS,
3231 "operands": (1, 0),
3232 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3233 "types": TYPE_FI32,
3234 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003235 "bitwise_not": {
3236 "op": Op.BITWISE_NOT,
3237 "operands": (1, 0),
3238 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3239 "types": TYPE_INT,
3240 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003241 "ceil": {
3242 "op": Op.CEIL,
3243 "operands": (1, 0),
3244 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3245 "types": TYPE_FP,
3246 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003247 "clz": {
3248 "op": Op.CLZ,
3249 "operands": (1, 0),
3250 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3251 "types": [DType.INT32],
3252 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003253 "exp": {
3254 "op": Op.EXP,
3255 "operands": (1, 0),
3256 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3257 "types": TYPE_FP,
3258 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003259 "floor": {
3260 "op": Op.FLOOR,
3261 "operands": (1, 0),
3262 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3263 "types": TYPE_FP,
3264 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003265 "log": {
3266 "op": Op.LOG,
3267 "operands": (1, 0),
3268 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3269 "types": TYPE_FP,
3270 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003271 "logical_not": {
3272 "op": Op.LOGICAL_NOT,
3273 "operands": (1, 0),
3274 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3275 "types": TYPE_BOOL,
3276 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003277 "negate": {
3278 "op": Op.NEGATE,
3279 "operands": (1, 0),
3280 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3281 "qgen": TosaQuantGen.qgUnary,
3282 "types": TYPE_INT_FP,
3283 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003284 "reciprocal": {
3285 "op": Op.RECIPROCAL,
3286 "operands": (1, 0),
3287 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3288 "types": TYPE_FP,
3289 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003290 "rsqrt": {
3291 "op": Op.RSQRT,
3292 "operands": (1, 0),
3293 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3294 "types": TYPE_FP,
3295 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003296 # Elementwise Ternary operators
3297 "select": {
3298 "op": Op.SELECT,
3299 "operands": (3, 0),
3300 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
3301 "types": TYPE_FIB,
3302 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003303 # Comparison operators
3304 "equal": {
3305 "op": Op.EQUAL,
3306 "operands": (2, 0),
3307 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
3308 "types": TYPE_FI32,
3309 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003310 "greater_equal": {
3311 "op": Op.GREATER_EQUAL,
3312 "operands": (2, 0),
3313 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
3314 "types": TYPE_FI32,
3315 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003316 "greater": {
3317 "op": Op.GREATER,
3318 "operands": (2, 0),
3319 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
3320 "types": TYPE_FI32,
3321 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003322 # Reduction operators
3323 "reduce_all": {
3324 "op": Op.REDUCE_ALL,
3325 "operands": (1, 0),
3326 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3327 "types": TYPE_BOOL,
3328 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003329 "reduce_any": {
3330 "op": Op.REDUCE_ANY,
3331 "operands": (1, 0),
3332 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3333 "types": TYPE_BOOL,
3334 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003335 "reduce_max": {
3336 "op": Op.REDUCE_MAX,
3337 "operands": (1, 0),
3338 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3339 "types": TYPE_INT_FP,
3340 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003341 "reduce_min": {
3342 "op": Op.REDUCE_MAX,
3343 "operands": (1, 0),
3344 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3345 "types": TYPE_INT_FP,
3346 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003347 "reduce_product": {
3348 "op": Op.REDUCE_PRODUCT,
3349 "operands": (1, 0),
3350 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3351 "types": TYPE_FP,
3352 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003353 "reduce_sum": {
3354 "op": Op.REDUCE_SUM,
3355 "operands": (1, 0),
3356 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3357 "types": TYPE_FI32,
3358 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003359 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003360 "concat": {
3361 "op": Op.CONCAT,
3362 "operands": (2, 0),
Matthew Haddon818ab902021-07-27 09:12:49 +01003363 "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003364 "types": TYPE_FIB,
3365 },
3366 "pad": {
3367 "op": Op.PAD,
3368 "operands": (1, 0),
3369 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
3370 "qgen": TosaQuantGen.qgPad,
3371 "types": TYPE_FIB,
3372 },
3373 "reshape": {
3374 "op": Op.RESHAPE,
3375 "operands": (1, 0),
3376 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
3377 "types": TYPE_FIB,
3378 },
3379 "reverse": {
3380 "op": Op.REVERSE,
3381 "operands": (1, 0),
3382 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
3383 "types": TYPE_FIB,
3384 },
3385 "slice": {
3386 "op": Op.SLICE,
3387 "operands": (1, 0),
3388 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
3389 "types": TYPE_FIB,
3390 },
3391 "tile": {
3392 "op": Op.TILE,
3393 "operands": (1, 0),
3394 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
3395 "types": TYPE_FIB,
3396 },
3397 "transpose": {
3398 "op": Op.TRANSPOSE,
3399 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003400 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003401 "build_fcn": (
3402 build_transpose,
3403 TosaTensorGen.tgBasic,
3404 TosaArgGen.agTranspose,
3405 ),
3406 "types": TYPE_FIB,
3407 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003408 # Data nodes
3409 "const": {
3410 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003411 "operands": (0, 1),
3412 "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
Jared Smolens573ecd42021-03-04 15:24:10 -08003413 "types": TYPE_FIB,
3414 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003415 "identity": {
3416 "op": Op.IDENTITY,
3417 "operands": (1, 0),
3418 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
3419 "types": TYPE_FIB,
3420 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003421 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003422 "gather": {
3423 "op": Op.GATHER,
3424 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3425 "operands": (1, 0),
3426 "rank": (3, 3),
3427 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
3428 "types": TYPE_INT_FP,
3429 },
3430 "scatter": {
3431 "op": Op.SCATTER,
3432 # Only specify 'values_in' tensor here.
3433 #'indices' and 'input' are generated in op building stage
3434 "operands": (2, 0),
3435 "rank": (3, 3),
3436 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
3437 "types": TYPE_INT_FP,
3438 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003439 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003440 "resize": {
3441 "op": Op.RESIZE,
3442 "operands": (1, 0),
3443 "rank": (4, 4),
3444 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
3445 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Matthew Haddone86fd342021-09-07 16:12:21 +01003446 "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride),
3447 "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
3448 TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
Matthew Haddon848efb42021-09-09 12:30:53 +01003449 TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType,
3450 TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003451 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003452 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003453 "cast": {
3454 "op": Op.CAST,
3455 "operands": (1, 0),
3456 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
3457 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
3458 },
3459 "rescale": {
3460 "op": Op.RESCALE,
3461 "operands": (1, 0),
3462 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01003463 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Kevin Cheng550ccc52021-03-03 11:21:43 -08003464 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003465 # Custom
3466 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003467 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07003468 # Two varients of cond_if, one that generates one of two constant tensors (no
3469 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
3470 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003471 "cond_if_const": {
3472 "op": Op.COND_IF,
3473 "operands": (0, 2),
3474 "build_fcn": (
3475 build_cond_if_const,
3476 TosaTensorGen.tgBasic,
3477 TosaArgGen.agCondIf,
3478 ),
3479 "types": [DType.BOOL],
3480 },
3481 "cond_if_binary": {
3482 "op": Op.COND_IF,
3483 "operands": (2, 0),
3484 "build_fcn": (
3485 build_cond_if_binary,
3486 TosaTensorGen.tgBasic,
3487 TosaArgGen.agCondIf,
3488 ),
3489 "types": TYPE_FI32,
3490 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003491 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003492 "while_loop": {
3493 "op": Op.WHILE_LOOP,
3494 "operands": (0, 1),
3495 "build_fcn": (
3496 build_while_loop,
3497 TosaTensorGen.tgBasic,
3498 TosaArgGen.agWhileLoop,
3499 ),
3500 "types": [DType.INT32],
3501 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003502 }
3503
Kevin Cheng550ccc52021-03-03 11:21:43 -08003504
Eric Kunzee5e26762020-10-13 16:11:07 -07003505class OutputShaper:
3506 # Methods in this class compute the expected output shape and datatype
3507 # for common classes of operations
3508 def __init__(self):
3509 pass
3510
3511 # These methods return arguments that can be used for
3512 # creating a new output tensor
3513 @staticmethod
3514 def binaryBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003515 assert len(a.shape) == len(b.shape)
3516 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003517
3518 shape = []
3519 for i in range(len(a.shape)):
3520 if a.shape[i] == 1:
3521 shape.append(b.shape[i])
3522 else:
3523 shape.append(a.shape[i])
3524
Kevin Cheng550ccc52021-03-03 11:21:43 -08003525 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003526
3527 @staticmethod
3528 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003529 assert len(a.shape) == len(b.shape)
3530 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003531
3532 shape = []
3533 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003534 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07003535 shape.append(a.shape[i])
3536
Kevin Cheng550ccc52021-03-03 11:21:43 -08003537 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003538
3539 @staticmethod
3540 def unaryOp(ser, a):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003541 return ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003542
3543 @staticmethod
3544 def selectOp(ser, cond, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003545 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
3546 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003547
3548 shape = []
3549 for i in range(len(a.shape)):
3550 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
3551
Kevin Cheng550ccc52021-03-03 11:21:43 -08003552 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003553
3554 @staticmethod
3555 def binaryComparisonOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003556 assert len(a.shape) == len(b.shape)
3557 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003558
3559 # Do broadcast
3560 shape = []
3561 for i in range(len(a.shape)):
3562 if a.shape[i] == 1:
3563 shape.append(b.shape[i])
3564 else:
3565 shape.append(a.shape[i])
3566
3567 # Force the output type to bool
Kevin Cheng550ccc52021-03-03 11:21:43 -08003568 return ser.addOutput(shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07003569
3570 @staticmethod
3571 def reduceOp(ser, a, axis):
3572
3573 shape = a.shape.copy()
3574
3575 shape[axis] = 1
3576
Kevin Cheng550ccc52021-03-03 11:21:43 -08003577 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003578
3579 @staticmethod
3580 def argmaxOp(ser, a, axis):
3581 shape = a.shape.copy()
3582 del shape[axis]
Kevin Cheng550ccc52021-03-03 11:21:43 -08003583 return ser.addOutput(shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07003584
3585 @staticmethod
3586 def conv2dOp(ser, ifm, filter, strides, padding, dilations):
3587
3588 # IFM: NHWC
3589 # Filter: OHWI
3590 # OFM: NHWC
3591
3592 if len(padding) == 2:
3593 # Expand padding to 4 parameters in the case of transpose_conv2d
3594 # From H,W to T,B,L,R
3595 padding = [padding[0], padding[0], padding[1], padding[1]]
3596
Kevin Cheng550ccc52021-03-03 11:21:43 -08003597 h = (
3598 ifm.shape[1]
3599 - filter.shape[1]
3600 - (filter.shape[1] - 1) * (dilations[0] - 1)
3601 + padding[0]
3602 + padding[1]
3603 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003604
Kevin Cheng550ccc52021-03-03 11:21:43 -08003605 w = (
3606 ifm.shape[2]
3607 - filter.shape[2]
3608 - (filter.shape[2] - 1) * (dilations[1] - 1)
3609 + padding[2]
3610 + padding[3]
3611 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003612
Eric Kunzee5e26762020-10-13 16:11:07 -07003613 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
3614
Kevin Cheng3a478572021-01-22 17:21:02 -08003615 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003616 out_dtype = DType.INT32
3617 elif ifm.dtype == DType.INT16:
3618 out_dtype = DType.INT48
3619 elif ifm.dtype == DType.FLOAT:
3620 out_dtype = DType.FLOAT
3621 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003622 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003623
Kevin Cheng550ccc52021-03-03 11:21:43 -08003624 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003625
3626 @staticmethod
Kevin Cheng1533b852021-09-01 12:51:58 -07003627 def conv3dOp(ser, ifm, filter, strides, padding, dilations):
3628
3629 # IFM: NDHWC
3630 # Filter: ODHWI
3631 # OFM: NDHWC
3632
3633 d = (
3634 ifm.shape[1]
3635 - filter.shape[1]
3636 - (filter.shape[1] - 1) * (dilations[0] - 1)
3637 + padding[0]
3638 + padding[1]
3639 ) // strides[0] + 1
3640
3641 h = (
3642 ifm.shape[2]
3643 - filter.shape[2]
3644 - (filter.shape[2] - 1) * (dilations[1] - 1)
3645 + padding[2]
3646 + padding[3]
3647 ) // strides[1] + 1
3648
3649 w = (
3650 ifm.shape[3]
3651 - filter.shape[3]
3652 - (filter.shape[3] - 1) * (dilations[2] - 1)
3653 + padding[4]
3654 + padding[5]
3655 ) // strides[2] + 1
3656
3657 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
3658
3659 if ifm.dtype == DType.INT8:
3660 out_dtype = DType.INT32
3661 elif ifm.dtype == DType.INT16:
3662 out_dtype = DType.INT48
3663 elif ifm.dtype == DType.FLOAT:
3664 out_dtype = DType.FLOAT
3665 else:
3666 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
3667
3668 return ser.addOutput(ofm_shape, out_dtype)
3669
3670 @staticmethod
Eric Kunzee5e26762020-10-13 16:11:07 -07003671 def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
3672 # IFM: NHWC
3673 # Filter: HWCM
3674 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08003675 h = (
3676 ifm.shape[1]
3677 - filter.shape[0]
3678 - (filter.shape[0] - 1) * (dilations[0] - 1)
3679 + padding[0]
3680 + padding[1]
3681 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003682
Kevin Cheng550ccc52021-03-03 11:21:43 -08003683 w = (
3684 ifm.shape[2]
3685 - filter.shape[1]
3686 - (filter.shape[1] - 1) * (dilations[1] - 1)
3687 + padding[2]
3688 + padding[3]
3689 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07003690
Eric Kunzee5e26762020-10-13 16:11:07 -07003691 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
3692
Kevin Cheng3a478572021-01-22 17:21:02 -08003693 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003694 out_dtype = DType.INT32
3695 elif ifm.dtype == DType.INT16:
3696 out_dtype = DType.INT48
3697 elif ifm.dtype == DType.FLOAT:
3698 out_dtype = DType.FLOAT
3699 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003700 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003701
Kevin Cheng550ccc52021-03-03 11:21:43 -08003702 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003703
3704 @staticmethod
3705 def pool2dOp(ser, ifm, kernel, stride, pad):
3706 # input: NHWC
3707 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
3708 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
3709
Eric Kunzee5e26762020-10-13 16:11:07 -07003710 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Kevin Cheng550ccc52021-03-03 11:21:43 -08003711 return ser.addOutput(ofm_shape, ifm.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003712
3713 @staticmethod
3714 def fullyConnectedOp(ser, input, filter):
3715 # input: N, IC
3716 # filter: OC, IC
3717 # output: N, OC
3718
3719 output_shape = [input.shape[0], filter.shape[0]]
3720
Kevin Cheng3a478572021-01-22 17:21:02 -08003721 if input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003722 out_dtype = DType.INT32
3723 elif input.dtype == DType.INT16:
3724 out_dtype = DType.INT48
3725 elif input.dtype == DType.FLOAT:
3726 out_dtype = DType.FLOAT
3727 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003728 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003729
Kevin Cheng550ccc52021-03-03 11:21:43 -08003730 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003731
3732 @staticmethod
3733 def matmulOp(ser, a, b):
Kevin Cheng2d60f002021-06-09 14:18:32 -07003734 # a: N, H, C
3735 # b: N, C, W
3736 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07003737
Kevin Cheng2d60f002021-06-09 14:18:32 -07003738 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07003739
Kevin Cheng3a478572021-01-22 17:21:02 -08003740 if a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003741 out_dtype = DType.INT32
3742 elif a.dtype == DType.INT16:
3743 out_dtype = DType.INT48
3744 elif a.dtype == DType.FLOAT:
3745 out_dtype = DType.FLOAT
3746 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003747 raise Exception("UNsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003748
Kevin Cheng550ccc52021-03-03 11:21:43 -08003749 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003750
3751 @staticmethod
Matthew Haddon818ab902021-07-27 09:12:49 +01003752 def concatOp(ser, axis, *a):
3753 input1 = a[0]
3754 remaining_inputs = a[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07003755
Matthew Haddon818ab902021-07-27 09:12:49 +01003756 output_shape = input1.shape.copy()
Eric Kunzee5e26762020-10-13 16:11:07 -07003757
Matthew Haddon818ab902021-07-27 09:12:49 +01003758 output_shape[axis] = input1.shape[axis]
3759
3760 for tensor in remaining_inputs:
3761 output_shape[axis] += tensor.shape[axis]
3762
3763 return ser.addOutput(output_shape, input1.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003764
3765 @staticmethod
3766 def padOp(ser, a, padding):
3767
3768 output_shape = a.shape.copy()
3769
3770 for i in range(len(output_shape)):
3771 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
3772
Kevin Cheng550ccc52021-03-03 11:21:43 -08003773 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003774
3775 @staticmethod
3776 def reshapeOp(ser, a, shape):
3777 output_shape = shape.copy()
3778
3779 totalElements = 1
3780 for i in a.shape:
3781 totalElements *= i
3782
3783 # If there are any -1 elements, figure out what that dimension must be
3784 totalOutputElements = 1
3785 for i in output_shape:
3786 if i != -1:
3787 totalOutputElements *= i
3788
3789 # And fill it in
3790 for i in range(len(output_shape)):
3791 if output_shape[i] == -1:
3792 output_shape[i] = totalElements // totalOutputElements
3793
Kevin Cheng550ccc52021-03-03 11:21:43 -08003794 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003795
3796 @staticmethod
3797 def sliceOp(ser, a, begin, size):
3798
3799 output_shape = size.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08003800 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003801
3802 @staticmethod
3803 def tileOp(ser, a, multiples):
3804
3805 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08003806 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07003807
3808 for i in range(len(output_shape)):
3809 output_shape[i] = a.shape[i] * multiples[i]
3810
Kevin Cheng550ccc52021-03-03 11:21:43 -08003811 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003812
3813 @staticmethod
3814 def transposeOp(ser, a, perms):
3815 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08003816 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07003817
3818 for i in range(len(output_shape)):
3819 output_shape[i] = a.shape[perms[i]]
3820
Kevin Cheng550ccc52021-03-03 11:21:43 -08003821 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003822
3823 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -08003824 def gatherOp(ser, values, indices):
3825 assert len(values.shape) == 3
3826 assert len(indices.shape) == 2
3827 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07003828
Kevin Cheng77d0f762020-11-24 10:26:32 -08003829 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
3830
Kevin Cheng550ccc52021-03-03 11:21:43 -08003831 return ser.addOutput(output_shape, values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08003832
3833 @staticmethod
3834 def scatterOp(ser, values_in, indices, input):
3835 assert len(values_in.shape) == 3
3836 assert len(indices.shape) == 2
3837 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08003838 assert values_in.shape[0] == indices.shape[0] # N
3839 assert input.shape[1] == indices.shape[1] # W
3840 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08003841
3842 output_shape = values_in.shape
3843
Kevin Cheng550ccc52021-03-03 11:21:43 -08003844 return ser.addOutput(output_shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003845
3846 @staticmethod
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003847 def tableOp(ser, input, table_dtype):
3848 # Same shape as the input, but dtype dependent on table dtype
3849 assert table_dtype == DType.INT16 or table_dtype == DType.INT8
3850 output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
3851 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003852
3853 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08003854 def resizeOp(
3855 ser,
3856 input,
3857 mode,
3858 stride,
3859 offset,
3860 shift,
3861 stride_fp,
3862 offset_fp,
3863 output_dims,
3864 input_dtype,
3865 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01003866 error_name = None
Kevin Cheng550ccc52021-03-03 11:21:43 -08003867 ):
Matthew Haddon848efb42021-09-09 12:30:53 +01003868 if error_name == ErrorIf.WrongRank:
3869 output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
3870 else:
3871 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07003872
Kevin Cheng550ccc52021-03-03 11:21:43 -08003873 return ser.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003874
3875 @staticmethod
3876 def typeConversionOp(ser, val, out_dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003877 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003878
3879 @staticmethod
3880 def transposeConv2DOp(ser, ifm, output_shape):
Kevin Cheng3a478572021-01-22 17:21:02 -08003881 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07003882 out_dtype = DType.INT32
3883 elif ifm.dtype == DType.INT16:
3884 out_dtype = DType.INT48
3885 elif ifm.dtype == DType.FLOAT:
3886 out_dtype = DType.FLOAT
3887 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003888 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07003889
Kevin Cheng550ccc52021-03-03 11:21:43 -08003890 return ser.addOutput(output_shape, out_dtype)