#!/usr/bin/env python3

# Copyright (c) 2020-2021, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import argparse
import sys
import re
import os
import subprocess
import shlex
import json
import glob
import math
import queue
import threading
import traceback
import itertools

from enum import IntEnum, Enum, unique
from tosa_ref_run import TosaReturnCode

# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(
    os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
)
import tosa_serializer as ts
from tosa_serializer import *
import tosa

# Convenience variables to the flatc-generated types that should be enums, but aren't
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()


class TosaQuantGen:
    """QuantizedInfo random generator helper functions.  Specify with 'qgen': in the operator definition"""

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype):
        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        if dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.UnaryQuantInfo(
            TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
        )
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList):
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3
        input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
        weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.MatMulQuantInfo(
            TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
        )
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift
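        # Illustrative example (assumed values, not from the reference model):
        # scaleFp=0.75 with scale32=True gives math.frexp -> m=0.75, shift=0,
        # so multiplier = round(0.75 * 2**31) = 1610612736 and
        # shift = -0 + 31 = 31, i.e. 0.75 ~= 1610612736 * 2**-31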

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(scaleFp, scaleBits, m, multiplier, shift))

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift


class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.  The actual random data is generated separately for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank):
        pl, const = opName["operands"]

        assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
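        # Illustrative example (assumed values): with shape [3, 4, 5],
        # pl + const == 2 and bcast_idx == 1, the result could be
        # [[3, 4, 5], [3, 1, 5]] - the chosen input gets one randomly
        # selected dimension set to 1 so that it broadcasts.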
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If this is the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
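        # e.g. (illustrative) with tensor_shape_range == (1, 32) the modulo
        # below limits the depthwise multiplier M to the range [1, 8]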
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis):
        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
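        # Illustrative example (assumed values): shape == [2, 8, 3], axis == 1
        # and four input shapes yield [[2, 8, 3], [2, 4, 3], [2, 2, 3], [2, 2, 3]]
        # - the original shape followed by successive halvings of the concat axis.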
        shape = shapeList[0]
        if len(shapeList) == 2 or shape[axis] < len(shapeList):
            return shapeList

        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())
            shape[axis] = remaining_length
            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList


class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for operators that take
    attributes or other parameters.  The return value is a list of (descriptive_name, [arglist])
    tuples where the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function."""

    def __init__(self):
        pass

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype):
        """Build the axis argument for operators that take a single axis"""
        axes = []

        shape = shapeList[0]

        for a in range(0, len(shape)):
            axes.append(("axis{}".format(a), [a]))
        return axes

    @staticmethod
    def agConv(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]
        # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
        k = [int(x) for x in opName.split("_")[-1].split("x")]

        # Check the rank
        rank = 5 if opName.startswith("conv3d") else 4
        assert len(ifm_shape) == rank
        assert len(filter_shape) == rank

        # kernel rank omits batch and channels
        k_rank = rank - 2

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
        s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * k_rank))}
        d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

        # add some oversize argument values
        if max(ifm_shape) < 64:
            bigPadding = 9
            paddings.update(
                {x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))}
            )
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
            bigDilation = 7
            dilations.update(
                {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
            )

        # There are too many parameter combinations, so generate them sparsely
        # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
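        # Illustrative arithmetic: 2000 combinations would give sparsity
        # 2000 // 100 + 1 = 21, bumped to 23 by the loop below (21 and 22 share
        # factors with 3 and 2), keeping roughly every 23rd combination; anything
        # under ~1200 combinations is enumerated in full (sparsity forced to 1).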
        sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
        if sparsity < 13:
            sparsity = 1
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # padding must not exceed the kernel size ?
                        # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
                        # the padded shape must exceed the kernel size
                        and (ifm_shape[1] + p[0] + p[1]) > k[0]
                        and (ifm_shape[2] + p[2] + p[3]) > k[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
                        # the padded shape must exceed the dilation
                        and (ifm_shape[1] + p[0] + p[1]) > d[0]
                        and (ifm_shape[2] + p[2] + p[3]) > d[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
                    ):
                        arg_list.append(
                            (
                                "st{}_pad{}_dilat{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in p]),
                                    "".join([str(x) for x in d]),
                                ),
                                [s, p, d],
                            )
                        )
                    n += 1

        return arg_list

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 2))}
        s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * 2))}

        # add some oversize argument values
        if max(ifm_shape) < 64:
            bigPadding = 9
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
            bigDilation = 7
            dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})

        # There are too many parameter combinations, so generate them sparsely
        # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
        sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
        if sparsity < 13:
            sparsity = 1
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if n % sparsity == 0:
                        # Determine the output shape
                        oh = (
                            ifm_shape[1]
                            - filter_shape[1]
                            - (filter_shape[1] - 1) * (d[0] - 1)
                            + 2 * p[0]
                        ) // s[0] + 1
                        ow = (
                            ifm_shape[2]
                            - filter_shape[2]
                            - (filter_shape[2] - 1) * (d[1] - 1)
                            + 2 * p[1]
                        ) // s[1] + 1
                        os = [ifm_shape[0], oh, ow, filter_shape[0]]
                        arg_list.append(
                            (
                                "st{}_pad{}_dilat{}_os{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in p]),
                                    "".join([str(x) for x in d]),
                                    "x".join([str(x) for x in os]),
                                ),
                                [s, p, d, os],
                            )
                        )
                    n += 1

        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype):
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))
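        # e.g. for rank 2 with pad values {0, 1} this yields 4**2 == 16
        # combinations, named pad0000, pad0001, ... pad1111 (one before/after
        # pair per dimension)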

        for paddings in shape_pad_values:
            name = "pad"
            for r in range(rank):
                before, after = paddings[r]
                name = f"{name}{before}{after}"
            arg_list.append((name, [np.array(paddings)]))

        return arg_list

    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype):
        arg_list = []

        shape = shapeList[0]
        assert len(shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 2)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        # add some oversize argument values
        bigStride = 7
        strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
        bigKernel = 6
        kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
        if max(shape) < 64:
            # padding must be less than the kernel size
            bigPadding = bigKernel - 1
            paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})

        # There are too many parameter combinations, so generate them sparsely
        sparsity = len(paddings) * len(strides) * len(kernels) // 500 + 1
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for k in sorted(list(kernels)):
                    if (
                        n % sparsity == 0
                        # padding must not exceed the kernel size
                        and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                        # the padded shape must exceed the kernel size
                        and (shape[1] + p[0] + p[1]) > k[0]
                        and (shape[2] + p[2] + p[3]) > k[1]
                    ):
                        arg_list.append(
                            (
                                "st{}_kern{}_pad{}".format(
                                    "".join([str(x) for x in s]),
                                    "".join([str(x) for x in k]),
                                    "".join([str(x) for x in p]),
                                ),
                                [s, p, k],
                            )
                        )
                    n += 1

        return arg_list

    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype):
        arg_list = []

        # Enumerate the output types here
        if inDtype == DType.INT8:
            dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT16:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT32:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FLOAT:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))

        return arg_list

    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype):
        arg_list = []

        # Enumerate the output types here
        for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
            if inDtype == DType.UINT8 and dtype != DType.INT8:
                # The only output dtype for UINT8 is INT8, skip all other combinations
                continue
            if inDtype != DType.INT8 and dtype == DType.UINT8:
                # The only input dtype for UINT8 is INT8, skip all other combinations
                continue

            for scale32 in [False, True]:
                for double_round in [False, True]:
                    for per_channel in [False, True]:

                        if inDtype == DType.INT48 and scale32:
                            # Illegal condition.  Must be scale32=False
                            continue
                        if double_round and not scale32:
                            # Illegal condition.  ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[dtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [dtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

    @staticmethod
    def agMul(testGen, opName, shapeList, dtype):
        arg_list = []

        if dtype is DType.INT32:
            for p in range(testGen.args.num_rand_permutations):

                shift = testGen.randInt(0, 32)

                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("perm0_shift0", [0]))

        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype):
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    # Helper function for reshape.  Gets some factors of a larger number.
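    # Note that only factors up to sqrt(val) are returned, e.g. an illustrative
    # call getFactors(24) gives [1, 2, 3, 4]; agReshape compensates by drawing
    # the final dimension from the remaining element count.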
    @staticmethod
    def getFactors(val, start=1):
        factors = []

        for i in range(start, int(np.sqrt(val)) + 1):
            if (val % i) == 0:
                factors.append(i)

        return factors

    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype):
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast.  Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)

                # Toss in a -1 sometimes
                minusOne = testGen.randInt(0, newRank * 4)
                if minusOne < newRank:
                    newShape[minusOne] = -1

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list

    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]

        # Get all permutations
        permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
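        # e.g. a rank-3 input has 3! == 6 possible axis orderings:
        # (0,1,2), (0,2,1), (1,0,2), (1,2,0), (2,0,1), (2,1,0)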

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Get random permutation generator that uses all permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create list of required amount of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list

    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            begin = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    begin.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    begin.append(0)
                    size.append(1)

            if valid:
                arg_list.append(("perm{}".format(p), [begin, size]))
        return arg_list

    @staticmethod
    def agTile(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random, but small multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                if ifm_shape[i] > 1000:
                    # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
                    multiples.append(1)
                elif max(ifm_shape) > 1000:
                    multiples.append(2)
                else:
                    multiples.append(testGen.randInt(1, 4))
            arg_list.append(("perm{}".format(p), [multiples]))

        return arg_list

    @staticmethod
    def agResize(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]

        for m in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations.  Pick legal output types
            if m == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif m == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):

                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    output_dims = [testGen.randInt(1), testGen.randInt(1)]
                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w
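                    # The integer path below quantizes these ratios to fixed
                    # point, e.g. (illustrative) resizing H from 16 to 8 gives
                    # fp_stride_y = 2.0, which with shift == 11 becomes
                    # stride_y = round(2.0 * 2**11) = 4096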

                    if outputDType == DType.FLOAT:
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]
                        arg_list.append(
                            (
                                "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                    "N" if m == ResizeMode.NEAREST else "B",
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride_fp[0],
                                    stride_fp[1],
                                    offset_fp[0],
                                    offset_fp[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )
                    else:
                        shift = 11
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))

                        while (
                            stride_y >= 32768
                            or stride_x >= 32768
                            or offset_y >= 32768
                            or offset_x >= 32768
                            or offset_y < -32768
                            or offset_x < -32768
                        ):
                            shift = shift - 1
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                        arg_list.append(
                            (
                                "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                    "N" if m == ResizeMode.NEAREST else "B",
                                    shift,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride[0],
                                    stride[1],
                                    offset[0],
                                    offset[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )

        return arg_list

    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype):
        # While loop: 0 iterations, 1, more than 1
        arg_list = []

        for iter in [0, 1, 4]:
            arg_list.append(("iter{}".format(iter), [iter]))

        return arg_list


class TosaInvalidValidator:

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        stride = args[1]
        stride_fp = args[4]
        output_dtype = args[8]

        if mode == ResizeMode.BILINEAR:
            # Invalid output data type / Invalid input datatype
            # The only legal combinations are INT8->INT32, INT16->INT48 and FLOAT->FLOAT
            return not (
                (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                or (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                or (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
            )
        elif mode == ResizeMode.NEAREST:
            # Invalid output data type / Invalid input datatype
            return (input_dtype != output_dtype) or (
                input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT]
            )
        else:
            # Invalid resize mode
            return True

    @staticmethod
    def ivBadStride(**kwargs):
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        # stride and stride_fp are stored as [y, x]
        stride_y = args[1][0]
        stride_x = args[1][1]
        stride_fp_y = args[4][0]
        stride_fp_x = args[4][1]

        if input_dtype == DType.FLOAT:
            if stride_fp_x <= 0 or stride_fp_y <= 0:
                # Negative or zero stride
                return True
        else:
            if stride_x <= 0 or stride_y <= 0:
                # Negative or zero stride
                return True
        return False

    @staticmethod
    def ivHeightWidthSmallerZero(**kwargs):
        opName = kwargs["opName"]

        inputShapes = kwargs["shapeList"]
        input = inputShapes[0]
        if not opName.endswith("pool2d"):
            filter = inputShapes[1]

        args = kwargs["args"]
        strides = args[0]
        padding = args[1]
        dilations = args[2]
        if opName.endswith("pool2d"):
            kernel = args[2]

        if opName.startswith("conv2d"):
            h = (
                input[1]
                - filter[1]
                - (filter[1] - 1) * (dilations[0] - 1)
                + padding[0]
                + padding[1]
            ) // strides[0] + 1

            w = (
                input[2]
                - filter[2]
                - (filter[2] - 1) * (dilations[1] - 1)
                + padding[2]
                + padding[3]
            ) // strides[1] + 1
        elif opName.startswith("depthwise_conv2d"):
            h = (
                input[1]
                - filter[0]
                - (filter[0] - 1) * (dilations[0] - 1)
                + padding[0]
                + padding[1]
            ) // strides[0] + 1

            w = (
                input[2]
                - filter[1]
                - (filter[1] - 1) * (dilations[1] - 1)
                + padding[2]
                + padding[3]
            ) // strides[1] + 1
        elif opName.endswith("pool2d"):
            h = (input[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
            w = (input[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
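            # e.g. (illustrative) input H == 8, pad 0/0, stride 2, kernel 2:
            # h = (8 + 0 + 0 + 2 - 2) // 2 = 4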
        else:
            assert False, "Unrecognized Op"

        if h <= 0 or w <= 0:
            # Invalid parameter combination
            return True
        return False

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        args = kwargs["args"]
        output_shape = args[3]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Non-positive output shape
            return True
        return False


class TosaTestGen:
    # Maximum rank of tensor supported by test generator.
    TOSA_TENSOR_MAX_RANK = 6

    def __init__(self, args):
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        self.ser = None
        self.rng = np.random.default_rng(self.random_seed)
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None

    def createSerializer(self, opName, testPath):
        self.testPath = os.path.join(opName, testPath)

        fullPath = os.path.join(self.basePath, self.testPath)
        os.makedirs(fullPath, exist_ok=True)
        self.ser = ts.TosaSerializer(fullPath)

    def getSerializer(self):
        return self.ser

    def serialize(self, testName):
        with open(
            os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
        ) as fd:
            fd.write(self.ser.serialize())

        with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
            fd.write(self.ser.writeJson("{}.tosa".format(testName)))

    def resetRNG(self, seed=None):
        if seed is None:
            seed = self.random_seed + 1
        self.rng = np.random.default_rng(seed)

    def getRandTensor(self, shape, dtype):
        if dtype == DType.BOOL:
            return np.bool_(self.rng.choice(a=[False, True], size=shape))
        # TOSA specific INT4 weight range from -7 to 7
        elif dtype == DType.INT4:
            return np.int32(self.rng.integers(low=-7, high=8, size=shape))
        elif dtype == DType.INT8:
            return np.int32(self.rng.integers(low=-128, high=128, size=shape))
        elif dtype == DType.UINT8:
            return np.int32(self.rng.integers(low=0, high=256, size=shape))
        elif dtype == DType.INT16:
            return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
        elif dtype == DType.INT32:
            return np.int32(
                self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
            )
        elif dtype == DType.INT48:
            return np.int64(
                self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
            )
        elif dtype == DType.FLOAT:
            return np.float32(self.rng.random(size=shape))
        else:
            raise Exception("Unrecognized Dtype: {}".format(dtype))

    def buildPlaceholderTensors(self, shape_list, dtype_list):
        placeholders = []

        assert len(shape_list) == len(dtype_list)

        for idx, shape in enumerate(shape_list):
            arr = self.getRandTensor(shape, dtype_list[idx])
            placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))

        return placeholders

    def buildConstTensors(self, shape_list, dtype_list):
        consts = []

        assert len(shape_list) == len(dtype_list)

        for idx, shape in enumerate(shape_list):
            arr = self.getRandTensor(shape, dtype_list[idx])
            consts.append(self.ser.addConst(shape, dtype_list[idx], arr))

        return consts

    def makeShape(self, rank):
        if self.targetted_shape:
            return np.int32(self.targetted_shape)
        return np.int32(
            self.rng.integers(
                low=self.args.tensor_shape_range[0],
                high=self.args.tensor_shape_range[1],
                size=rank,
            )
        )

    def setTargetShape(self, shape):
        self.targetted_shape = shape

    def randInt(self, low=0, high=256):
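        # Note: numpy's Generator.integers() excludes `high` by default, so
        # the default arguments draw a value from [0, 255]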
        return np.int32(self.rng.integers(low=low, high=high, size=1))[0]

    def getRandNumberDType(self, dtype):
        if dtype == DType.FLOAT:
            return self.rng.random()
        elif dtype == DType.BOOL:
            return self.rng.choice([False, True])
        # TOSA specific INT4 weight range from -7 to 7
        elif dtype == DType.INT4:
            low, high = (-7, 8)
        elif dtype == DType.INT8:
            low, high = (-128, 128)
        elif dtype == DType.INT16:
            low, high = (-32768, 32768)
        elif dtype == DType.INT32:
            low, high = (-(1 << 31), (1 << 31))
        elif dtype == DType.INT48:
            low, high = (-(1 << 47), (1 << 47))
            # Special size
            return np.int64(self.rng.integers(low, high, size=1))[0]
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

        return np.int32(self.rng.integers(low, high, size=1))[0]

    def shapeStr(self, shape):

        sStr = []
        # Convert to strings
        for i in shape:
            sStr.append(str(i))

        return "x".join(sStr)

    def typeStr(self, t):
        if isinstance(t, list):
            assert len(t) >= 2
            return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
        else:
            if t == DType.BOOL:
                return "b"
            elif t == DType.INT4:
                return "i4"
            elif t == DType.INT8:
                return "i8"
            elif t == DType.UINT8:
                return "u8"
            elif t == DType.INT16:
                return "i16"
            elif t == DType.INT32:
                return "i32"
            elif t == DType.INT48:
                return "i48"
            elif t == DType.FLOAT:
                return "float"
            else:
                raise Exception("Unknown dtype, cannot convert to string: {}".format(t))

    def typeWidth(self, t):
        """Get the datatype width for integer types"""
        if t == DType.INT4:
            return 4
        elif t == DType.INT8:
            return 8
        elif t == DType.UINT8:
            return 8
        elif t == DType.INT16:
            return 16
        elif t == DType.INT32:
            return 32
        elif t == DType.INT48:
            return 48
        else:
            raise Exception("Unknown dtype, cannot determine width: {}".format(t))

    # Argument generators
    # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
    # where the string descriptor is used to generate the test name and
    # the build_fcn_arg_list is expanded and passed to the operator test
    # build function

    def build_unary(self, op, a, qinfo=None):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        self.ser.addOperator(op, [a.name], [result_tens.name], None, qinfo)
        return result_tens

    def build_binary_broadcast(self, op, a, b):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
        return result_tens

    def build_binary_nonbroadcast(self, op, a, b):
        result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
        return result_tens

    def build_arithmetic_right_shift(self, op, a, b, round):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)

        attr = ts.TosaSerializerAttribute()
        attr.ArithmeticRightShiftAttribute(round)

        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
        return result_tens

    def build_mul(self, op, a, b, shift):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)

        # Special for multiply:
        # Force the result to INT32 for INT types
        if a.dtype != DType.FLOAT:
            result_tens.setDtype(DType.INT32)

        attr = ts.TosaSerializerAttribute()
        attr.MulAttribute(shift)

        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
        return result_tens

    def build_table(self, op, a):
        # Constant size depending on type, random values
        if a.dtype == DType.INT16:
            table_dtype = DType.INT16
            table_arr = self.getRandTensor([513], table_dtype)
        else:
            assert a.dtype == DType.INT8
            table_dtype = DType.INT8
            table_arr = self.getRandTensor([256], table_dtype)

        table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
        result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
        self.ser.addOperator(op, [a.name, table_tens.name], [result_tens.name], None)

        return result_tens

    def build_select(self, op, cond, a, b):
        result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
        self.ser.addOperator(op, [cond.name, a.name, b.name], [result_tens.name])
        return result_tens

    def build_comparison(self, op, a, b):
        result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
        return result_tens

    def build_argmax(self, op, a, axis):
        result_tens = OutputShaper.argmaxOp(self.ser, a, axis)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_pool2d(self, op, input, stride, pad, kernel, qinfo=None):
        result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)

        attr = ts.TosaSerializerAttribute()
        attr.PoolAttribute(kernel, stride, pad)

        self.ser.addOperator(op, [input.name], [result_tens.name], attr, qinfo)
        return result_tens

    def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
        assert len(padding) == 4
        result_tens = OutputShaper.conv2dOp(
            self.ser, ifm, filter, strides, padding, dilations
        )

        attr = ts.TosaSerializerAttribute()
        attr.ConvAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
        assert len(padding) == 6
        result_tens = OutputShaper.conv3dOp(
            self.ser, ifm, filter, strides, padding, dilations
        )

        attr = ts.TosaSerializerAttribute()
        attr.ConvAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_transpose_conv2d(
        self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
    ):
        assert len(outpad) == 2
        result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)

        attr = ts.TosaSerializerAttribute()
        attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_depthwise_conv2d(
        self, op, ifm, filter, bias, strides, padding, dilations, qinfo
    ):
        result_tens = OutputShaper.depthwiseConv2dOp(
            self.ser, ifm, filter, strides, padding, dilations
        )

        attr = ts.TosaSerializerAttribute()
        attr.ConvAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_fully_connected(self, op, ifm, filter, bias, qinfo):
        result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
        )
        return result_tens

    def build_matmul(self, op, a, b, qinfo):
        result_tens = OutputShaper.matmulOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], None, qinfo)
        return result_tens

1498 def build_reduce(self, op, a, axis):
1499 result_tens = OutputShaper.reduceOp(self.ser, a, axis)
1500
1501 attr = ts.TosaSerializerAttribute()
1502 attr.AxisAttribute(axis)
1503
        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1505 return result_tens
1506
1507 def build_clamp(self, op, a):
1508 result_tens = OutputShaper.unaryOp(self.ser, a)
1509
1510 attr = ts.TosaSerializerAttribute()
Jeremy Johnson18e26662021-07-22 16:15:29 +01001511 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07001512
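        # ClampAttribute appears to pack (min_int, max_int, min_fp, max_fp);
        # only the pair matching the tensor dtype is populated, the other
        # stays 0.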
1513 if a.dtype == DType.FLOAT:
1514 attr.ClampAttribute(0, 0, min(v), max(v))
1515 else:
1516 attr.ClampAttribute(min(v), max(v), 0, 0)
1517
1518 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1519 return result_tens
1520
1521 def build_leaky_relu(self, op, a):
1522 result_tens = OutputShaper.unaryOp(self.ser, a)
1523 attr = ts.TosaSerializerAttribute()
1524
1525 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
1526
1527 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1528 return result_tens
1529
1530 # Needs an additional type/input
1531 def build_prelu(self, op, a):
1532 result_tens = OutputShaper.unaryOp(self.ser, a)
1533
1534 self.ser.addOperator(op, [a.name], [result_tens.name])
1535 return result_tens
1536
1537 def build_relun(self, op, a):
1538 result_tens = OutputShaper.unaryOp(self.ser, a)
1539
1540 attr = ts.TosaSerializerAttribute()
1541
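        # ReluNAttribute appears to pack (max_int, max_fp); the slot for the
        # other dtype stays 0.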
1542 if a.dtype == DType.FLOAT:
1543 attr.ReluNAttribute(0, self.getRandNumberDType(a.dtype))
1544 else:
1545 attr.ReluNAttribute(self.getRandNumberDType(a.dtype), 0)
1546
1547 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1548 return result_tens
1549
1550 def build_sigmoid(self, op, a):
1551 result_tens = OutputShaper.unaryOp(self.ser, a)
1552 self.ser.addOperator(op, [a.name], [result_tens.name])
1553 return result_tens
1554
1555 def build_tanh(self, op, a):
1556 result_tens = OutputShaper.unaryOp(self.ser, a)
1557 self.ser.addOperator(op, [a.name], [result_tens.name])
1558 return result_tens
1559
Matthew Haddon818ab902021-07-27 09:12:49 +01001560 def build_concat(self, op, *a):
Kevin Cheng93a16282021-08-31 16:14:03 -07001561 assert type(a[-1]) == int
Matthew Haddon818ab902021-07-27 09:12:49 +01001562
        # The variable-length list of input tensors arrives with the axis
        # appended as its final element
1564 axis = a[-1]
1565 a = a[:-1]
1566
1567 result_tens = OutputShaper.concatOp(self.ser, axis, *a)
Eric Kunzee5e26762020-10-13 16:11:07 -07001568
1569 attr = ts.TosaSerializerAttribute()
1570 attr.AxisAttribute(axis)
1571
Matthew Haddon818ab902021-07-27 09:12:49 +01001572 input_tensor_names = []
1573 for tensor in a:
1574 input_tensor_names.append(tensor.name)
1575
        self.ser.addOperator(op, input_tensor_names, [result_tens.name], attr)
        return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001577
1578 def build_pad(self, op, a, padding, qinfo):
1579 result_tens = OutputShaper.padOp(self.ser, a, padding)
1580
1581 # Need to turn the padding array into a TOSA tensor here.
1582 # This is one of the few tensor operands that does not get
1583 # randomly generated
Kevin Cheng550ccc52021-03-03 11:21:43 -08001584 padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)
Eric Kunzee5e26762020-10-13 16:11:07 -07001585
        self.ser.addOperator(
            op, [a.name, padding_tens.name], [result_tens.name], None, qinfo
        )
        return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001589
1590 def build_reshape(self, op, a, newShape):
1591 result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)
1592
1593 attr = ts.TosaSerializerAttribute()
1594 attr.ReshapeAttribute(newShape)
1595
1596 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1597 return result_tens
1598
1599 def build_reverse(self, op, a, axis):
1600 result_tens = OutputShaper.unaryOp(self.ser, a)
1601
1602 attr = ts.TosaSerializerAttribute()
1603 attr.AxisAttribute(axis)
1604
1605 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1606 return result_tens
1607
1608 def build_transpose(self, op, a, perms):
1609 result_tens = OutputShaper.transposeOp(self.ser, a, perms)
1610
Kevin Cheng550ccc52021-03-03 11:21:43 -08001611 perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))
Eric Kunzee5e26762020-10-13 16:11:07 -07001612
1613 self.ser.addOperator(op, [a.name, perms_tens.name], [result_tens.name])
1614 return result_tens
1615
1616 def build_slice(self, op, a, begin, size):
1617 result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)
1618
1619 attr = ts.TosaSerializerAttribute()
1620 attr.SliceAttribute(begin, size)
1621
1622 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1623 return result_tens
1624
1625 def build_tile(self, op, a, multiples):
1626 result_tens = OutputShaper.tileOp(self.ser, a, multiples)
1627
1628 attr = ts.TosaSerializerAttribute()
1629 attr.TileAttribute(multiples)
1630
1631 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1632 return result_tens
1633
Kevin Cheng77d0f762020-11-24 10:26:32 -08001634 def build_gather(self, op, values):
Eric Kunzee5e26762020-10-13 16:11:07 -07001635
        # Create a new indices tensor here with data that doesn't exceed
        # the dimensions of the values tensor
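        # (For GATHER the values tensor is rank 3, interpreted as (N, K, C);
        # we draw W indices per batch uniformly from [0, K), giving an
        # indices tensor of shape (N, W).)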
1638
Kevin Cheng550ccc52021-03-03 11:21:43 -08001639 K = values.shape[1] # K
1640 W = self.randInt(
1641 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
1642 ) # W
1643 indicies_arr = np.int32(
1644 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
1645 ) # (N, W)
1646 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001647
Kevin Cheng77d0f762020-11-24 10:26:32 -08001648 result_tens = OutputShaper.gatherOp(self.ser, values, indicies)
Eric Kunzee5e26762020-10-13 16:11:07 -07001649
Kevin Cheng77d0f762020-11-24 10:26:32 -08001650 self.ser.addOperator(op, [values.name, indicies.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001651
1652 return result_tens
1653
Kevin Cheng77d0f762020-11-24 10:26:32 -08001654 def build_scatter(self, op, values_in, input):
1655
        # Create a new indices tensor here with data that doesn't exceed
        # the dimensions of the values_in tensor
1658
Kevin Cheng550ccc52021-03-03 11:21:43 -08001659 K = values_in.shape[1] # K
1660 W = input.shape[1] # W
1661 indicies_arr = np.int32(
1662 self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
1663 ) # (N, W)
1664 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001665
1666 result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)
1667
Kevin Cheng550ccc52021-03-03 11:21:43 -08001668 self.ser.addOperator(
1669 op, [values_in.name, indicies.name, input.name], [result_tens.name]
1670 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001671
1672 return result_tens
1673
Kevin Cheng550ccc52021-03-03 11:21:43 -08001674 def build_resize(
1675 self,
1676 op,
1677 input,
1678 mode,
1679 stride,
1680 offset,
1681 shift,
1682 stride_fp,
1683 offset_fp,
1684 output_dims,
1685 input_dtype,
1686 output_dtype,
1687 ):
1688 result_tens = OutputShaper.resizeOp(
1689 self.ser,
1690 input,
1691 mode,
1692 stride,
1693 offset,
1694 shift,
1695 stride_fp,
1696 offset_fp,
1697 output_dims,
1698 input_dtype,
1699 output_dtype,
1700 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001701
1702 attr = ts.TosaSerializerAttribute()
Kevin Cheng77d0f762020-11-24 10:26:32 -08001703
Kevin Cheng550ccc52021-03-03 11:21:43 -08001704 attr.ResizeAttribute(
1705 output_dims, stride, offset, shift, stride_fp, offset_fp, mode
1706 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001707
1708 self.ser.addOperator(op, [input.name], [result_tens.name], attr)
1709 return result_tens
1710
1711 def build_identityn(self, op, val, val2):
1712
Kevin Cheng550ccc52021-03-03 11:21:43 -08001713 result_tens = OutputShaper.unaryOp(self.ser, val)
Eric Kunzee5e26762020-10-13 16:11:07 -07001714 result_tens2 = OutputShaper.unaryOp(self.ser, val2)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001715 self.ser.addOperator(
1716 op, [val.name, val2.name], [result_tens.name, result_tens2.name]
1717 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001718 return result_tens
1719
1720 def build_placeholder(self, op, val):
1721 # Add an identity op to avoid warning in the reference model
1722 return self.build_unary(Op.IDENTITY, val)
1723
1724 # Type Conversion
1725 def build_cast(self, op, val, out_dtype):
1726 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
1727 self.ser.addOperator(op, [val.name], [result_tens.name])
1728 return result_tens
1729
1730 def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
1731 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
1732
1733 if per_channel:
1734 nc = val.shape[-1]
1735 else:
1736 nc = 1
1737
1738 in_type_width = self.typeWidth(val.dtype)
1739 out_type_width = self.typeWidth(out_dtype)
1740
Kevin Cheng3a478572021-01-22 17:21:02 -08001741 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001742 input_zp = self.randInt(-128, 128)
1743 in_type_width = in_type_width + 1
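            # (The +1 presumably accounts for the asymmetric zero point
            # widening the effective input range by up to an extra bit.)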
Kevin Chengacb550f2021-06-29 15:32:19 -07001744 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001745 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07001746 in_type_width = in_type_width + 1
1747 else:
1748 input_zp = 0
1749
Kevin Cheng3a478572021-01-22 17:21:02 -08001750 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001751 output_zp = self.randInt(-128, 128)
1752 out_type_width = out_type_width + 1
1753 elif out_dtype == DType.UINT8:
1754 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07001755 out_type_width = out_type_width + 1
1756 else:
1757 output_zp = 0
1758
        # Calculate scale based on:
        # scale = a * (2^output_width) / (2^input_width)
1761
1762 a = np.float32(self.rng.random(size=[nc]))
1763 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
1764
        if scale32:
            # Cap the scaling at 2^31 - 1 for scale32
            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
1769 else:
1770 # Cap the scaling at 2^15 - 1 for scale16
1771 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
1772
Kevin Cheng550ccc52021-03-03 11:21:43 -08001773 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001774
1775 multiplier_arr = np.int32(np.zeros(shape=[nc]))
1776 shift_arr = np.int32(np.zeros(shape=[nc]))
1777
1778 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001779 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
1780 scale_arr[i], scale32
1781 )
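        # Rough sketch of the expected normalization (assuming the usual TOSA
        # scheme where scale == multiplier * 2^-shift, with the multiplier
        # kept in [2^30, 2^31) for scale32): a scale of 1.5 maps to multiplier
        # 1610612736 and shift 30, since 1610612736 * 2^-30 == 1.5.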
Eric Kunzee5e26762020-10-13 16:11:07 -07001782
Kevin Cheng550ccc52021-03-03 11:21:43 -08001783 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07001784
1785 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08001786 attr.RescaleAttribute(
1787 input_zp,
1788 output_zp,
1789 multiplier_arr,
1790 shift_arr,
1791 scale32,
1792 double_round,
1793 per_channel,
1794 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001795
1796 self.ser.addOperator(op, [val.name], [result_tens.name], attr)
1797 return result_tens
1798
1799 def build_cond_if_const(self, op, then_tens, else_tens, cond):
        # For cond_if with constants, we're supplied with then/else tensors that we ignore
        # (except for the generated shape) and the condition. Build then/else blocks
        # and fill them with const nodes for the body.
1803
1804 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001805 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07001806
1807 # Make then/else tensors
1808 out_shape = then_tens.shape
Jeremy Johnson18e26662021-07-22 16:15:29 +01001809 then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
1810 else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07001811
1812 # And the result tensor based on any of the outputs
Kevin Cheng550ccc52021-03-03 11:21:43 -08001813 result_tens = self.ser.addOutput(out_shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07001814
1815 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001816 then_block = "THEN_BLOCK"
1817 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001818 attr = ts.TosaSerializerAttribute()
1819 attr.CondIfAttribute(then_block, else_block)
1820
1821 # Finally, build the op and the two blocks
1822 self.ser.addOperator(op, [cond_tens.name], [result_tens.name], attr)
1823
1824 self.ser.startBasicBlock(then_block)
1825 # Build the actual then/else tensors inside their blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001826 then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001827 self.ser.addOutputTensor(then_tens)
1828
1829 self.ser.startBasicBlock(else_block)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001830 else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001831 self.ser.addOutputTensor(else_tens)
1832
1833 return result_tens
1834
1835 def build_cond_if_binary(self, op, a, b, cond):
1836 # For cond_if with a binary op in the then/else blocks, take a and b and
1837 # alternately add or subtract them based on the condition
1838
1839 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001840 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07001841
Kevin Cheng550ccc52021-03-03 11:21:43 -08001842 result_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001843 self.ser.currBasicBlock.addOutput(result_tens.name)
1844
1845 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001846 then_block = "THEN_BLOCK"
1847 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001848 attr = ts.TosaSerializerAttribute()
1849 attr.CondIfAttribute(then_block, else_block)
1850
1851 # Finally, build the op and the two blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001852 self.ser.addOperator(
1853 op, [cond_tens.name, a.name, b.name], [result_tens.name], attr
1854 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001855
1856 self.ser.startBasicBlock(then_block)
1857 self.ser.addInputTensor(a)
1858 self.ser.addInputTensor(b)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001859 then_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001860 self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])
1861
1862 self.ser.startBasicBlock(else_block)
1863 self.ser.addInputTensor(a)
1864 self.ser.addInputTensor(b)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001865 else_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001866 self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])
1867
1868 return result_tens
1869
1870 def build_while_loop(self, op, a, iter_val):
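        # Builds a WHILE_LOOP whose BODY does acc += a and iter -= 1 while the
        # COND block tests iter > 0, so for a non-negative iter_val the
        # returned accumulator equals a * iter_val element-wise.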
Kevin Cheng550ccc52021-03-03 11:21:43 -08001871 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07001872
Kevin Cheng550ccc52021-03-03 11:21:43 -08001873 cond_block = "COND_BLOCK"
1874 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001875
1876 attr = ts.TosaSerializerAttribute()
1877 attr.WhileLoopAttribute(cond_block, body_block)
1878
1879 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001880 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001881 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001882 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07001883
1884 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08001885 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
1886 a_out = self.ser.addIntermediate(a.shape, a.dtype)
1887 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001888
1889 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08001890 self.ser.addOperator(
1891 op,
1892 [iter.name, a.name, acc.name],
1893 [iter_out.name, a_out.name, acc_out.name],
1894 attr,
1895 )
Kevin Chengb227ae52021-09-02 13:43:17 -07001896 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07001897
1898 # COND block (input: iter, output: cond_tens )
1899 self.ser.startBasicBlock(cond_block)
1900 self.ser.addInputTensor(iter)
1901 self.ser.addInputTensor(a)
1902 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001903 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
1904 cond_tens = self.ser.addOutput([], DType.BOOL)
1905 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001906
1907 # BODY block (input: a, acc, iter, output: a, acc, iter)
1908 # Note that local intermediate tensors need to be declared here for the outputs
1909 self.ser.startBasicBlock(body_block)
1910 self.ser.addInputTensor(iter)
1911 self.ser.addInputTensor(a)
1912 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001913 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
1914 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
1915 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001916 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
1917 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
1918 self.ser.addOutputTensor(iter_body_out)
1919 self.ser.addOutputTensor(a)
1920 self.ser.addOutputTensor(acc_body_out)
1921
1922 return acc_out
1923
Kevin Cheng550ccc52021-03-03 11:21:43 -08001924 def genOpTestList(
Matthew Haddon74567092021-07-16 15:38:20 +01001925 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
Kevin Cheng550ccc52021-03-03 11:21:43 -08001926 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07001927
1928 try:
1929 op = self.TOSA_OP_LIST[opName]
1930 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001931 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07001932
1933 # Initialize a new random number generator
1934 self.rng = np.random.default_rng(self.random_seed)
1935
Kevin Cheng550ccc52021-03-03 11:21:43 -08001936 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001937
1938 # Generate the lists of arguments
Kevin Cheng550ccc52021-03-03 11:21:43 -08001939 rmin, rmax = op["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001940
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01001941 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
1942 default_test_rank_range = range(1, 5)
1943
        # Each test list entry is a tuple of:
1945 # (opName, testNameStr, dtype, shapeList, argumentsList)
1946 testList = []
1947
1948 if not shapeFilter:
1949 shapeFilter = [None]
1950
Matthew Haddon74567092021-07-16 15:38:20 +01001951 # Positive test loop
1952 if testType in ['positive', 'both']:
1953 for r in range(rmin, rmax + 1):
Eric Kunzee5e26762020-10-13 16:11:07 -07001954
Matthew Haddon74567092021-07-16 15:38:20 +01001955 # Filter out the rank?
1956 if rankFilter is not None and r not in rankFilter:
1957 continue
Kevin Cheng1533b852021-09-01 12:51:58 -07001958 if opName.startswith("conv3d"):
1959 assert r == 5, "conv3d test must have input rank == 5"
1960 elif (
Matthew Haddon74567092021-07-16 15:38:20 +01001961 rankFilter is None
1962 and shapeFilter[0] is None
1963 and r not in default_test_rank_range
1964 ):
1965 continue
Eric Kunzee5e26762020-10-13 16:11:07 -07001966
Matthew Haddon74567092021-07-16 15:38:20 +01001967 for t in op["types"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07001968
Matthew Haddon74567092021-07-16 15:38:20 +01001969 # Filter tests based on dtype?
1970 if dtypeFilter is not None:
1971 if not (
1972 t in dtypeFilter
1973 or (isinstance(t, list) and t[0] in dtypeFilter)
1974 ):
1975 continue
Eric Kunzee5e26762020-10-13 16:11:07 -07001976
Matthew Haddon74567092021-07-16 15:38:20 +01001977 # Create the placeholder and const tensors
1978 for shape in shapeFilter:
1979 # A None shape chooses a random shape of a given rank
Eric Kunzee5e26762020-10-13 16:11:07 -07001980
Matthew Haddon74567092021-07-16 15:38:20 +01001981 # Filter out by rank
1982 if shape is not None and len(shape) != r:
1983 continue
Eric Kunzee5e26762020-10-13 16:11:07 -07001984
Matthew Haddon74567092021-07-16 15:38:20 +01001985 self.setTargetShape(shape)
1986 shapeList = tgen_fcn(self, op, r)
Eric Kunzee5e26762020-10-13 16:11:07 -07001987
Matthew Haddon74567092021-07-16 15:38:20 +01001988 shapeStr = self.shapeStr(shapeList[0])
1989 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07001990
                        # The argument list consists of (str, []) tuples pairing a string
                        # representation with the build-function argument list
1992 argList = []
1993 if agen_fcn:
1994 argList = agen_fcn(self, opName, shapeList, t)
Eric Kunzee5e26762020-10-13 16:11:07 -07001995 else:
Matthew Haddon74567092021-07-16 15:38:20 +01001996 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07001997
Matthew Haddon74567092021-07-16 15:38:20 +01001998 for argStr, args in argList:
1999 if argStr:
2000 testStr = "{}_{}_{}_{}".format(
2001 opName, shapeStr, typeStr, argStr
2002 )
2003 else:
2004 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
2005
2006 testList.append((opName, testStr, t, shapeList, args))
2007
        # Remove tests which are expected to fail but don't correlate to an ERROR_IF statement
2009 if "invalid_test_validators" in op:
2010 invalid_test_validators = op["invalid_test_validators"]
2011 clean_testList = []
            for test in testList:
                remove_test = False
                for validator_fcn in invalid_test_validators:
                    if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[3], args=test[4]):
                        remove_test = True
2017 if not remove_test:
2018 clean_testList.append(test)
2019 testList = clean_testList
2020
Matthew Haddon74567092021-07-16 15:38:20 +01002021 # Reset RNG so both positive and negative tests are reproducible
2022 self.resetRNG()
2023 # Negative test loop
2024 if testType in ['negative', 'both']:
2025 print("Negative tests unsupported")
Eric Kunzee5e26762020-10-13 16:11:07 -07002026
2027 return testList
2028
Kevin Cheng989cb052021-04-28 16:29:44 -07002029 def serializeTest(self, opName, testStr, dtype_or_dtypeList, shapeList, testArgs):
Eric Kunzee5e26762020-10-13 16:11:07 -07002030 try:
2031 op = self.TOSA_OP_LIST[opName]
2032 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002033 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002034
2035 # Create a serializer
2036 self.createSerializer(opName, testStr)
2037
Kevin Cheng550ccc52021-03-03 11:21:43 -08002038 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
2039 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07002040 num_operands = pCount + cCount
2041
2042 if isinstance(dtype_or_dtypeList, list):
2043 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07002044 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002045 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07002046 else:
2047 dtypeList = [dtype_or_dtypeList] * (num_operands)
2048
Kevin Cheng93a16282021-08-31 16:14:03 -07002049 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002050 assert (
2051 len(shapeList) == num_operands
2052 ), "shapeList length {} must match number of operands {}".format(
2053 len(shapeList), num_operands
2054 )
2055 assert (
2056 len(dtypeList) == num_operands
2057 ), "dtypeList length {} must match number of operands {}".format(
2058 len(dtypeList), num_operands
2059 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002060
2061 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002062 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002063 except KeyError:
2064 qgen = None
2065
2066 # Build the random tensor operands and the test
2067 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08002068
Jeremy Johnsonef509a42021-09-07 13:59:47 +01002069 if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32:
2070 # Make sure the operation does not cause value saturation - where
2071 # the number wraps due to limited number of bits to store the answer
2072 assert (
2073 pCount == 2 and cCount == 0
2074 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
2075
2076 placeholders = []
2077 add = (op["op"] == Op.ADD)
2078 a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
2079 b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
2080 if add:
2081 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
2082 else:
2083 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
2084
2085 # Work out the saturation limits
            max_i32 = (1 << 31) - 1
2087 min_i32 = -(1 << 31)
2088 max_arr = np.full(shapeList[1], max_i32)
2089 min_arr = np.full(shapeList[1], min_i32)
2090
2091 # Find how much values exceed the maximum/minimums
2092 sat_max_arr = np.maximum(res_arr - max_arr, 0)
2093 sat_min_arr = np.minimum(res_arr - min_arr, 0)
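            # Worked example (ADD): a == 2_000_000_000 and b == 500_000_000
            # give res == 2_500_000_000, which exceeds max_i32 by 352_516_353;
            # clipping b down by that amount leaves 147_483_647, and a + b
            # then lands exactly on max_i32 without wrapping.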
2094
2095 if not add:
2096 # Swap saturation values and negate values as we need to perform opposite operations
2097 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
2098
2099 # Create new array of unsaturated values by clipping values as needed
2100 b_unsat_arr = b_arr
2101 if (sat_max_arr != 0).any():
2102 # Clip values that cause saturation
2103 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
2104 # Reduce axes in unsaturated tensor to match original tensor
2105 for axis, dim in enumerate(b_arr.shape):
2106 if dim != b_unsat_arr.shape[axis]:
                        assert dim == 1, "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
2108 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
2109
2110 if (sat_min_arr != 0).any():
2111 # Clip values that cause saturation
2112 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
2113 # Reduce axes in unsaturated tensor to match original tensor
2114 for axis, dim in enumerate(b_arr.shape):
2115 if dim != b_unsat_arr.shape[axis]:
                        assert dim == 1, "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
2117 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
2118
2119 placeholders.append(
2120 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
2121 )
2122 placeholders.append(
2123 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
2124 )
2125
2126 tens.extend(placeholders)
2127 elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
2128 # Force value of operand[1] to be within [0, num_bits]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002129 assert (
2130 pCount == 2 and cCount == 0
2131 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
Kevin Chengaee1fac2020-11-11 13:54:06 -08002132
2133 placeholders = []
2134 for idx, shape in enumerate(shapeList[:]):
2135 if idx == 1:
Kevin Cheng989cb052021-04-28 16:29:44 -07002136 if dtypeList[idx] == DType.INT8:
Kevin Chengaee1fac2020-11-11 13:54:06 -08002137 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07002138 elif dtypeList[idx] == DType.INT16:
Kevin Chengaee1fac2020-11-11 13:54:06 -08002139 arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07002140 elif dtypeList[idx] == DType.INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -08002141 arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
2142 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002143 raise Exception("OpArithmeticRightShift: invalid input dtype")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002144 else:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002145 arr = self.getRandTensor(shape, dtypeList[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -07002146 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
Kevin Chengaee1fac2020-11-11 13:54:06 -08002147
2148 tens.extend(placeholders)
Matthew Haddona44ac5e2021-07-27 16:31:16 +01002149 elif op["op"] == Op.SELECT:
2150 # Set datatype of condition tensor to boolean
2151 dtypeList[0] = DType.BOOL
2152 tens.extend(
2153 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
2154 )
2155 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Matthew Haddon459443c2021-08-23 16:43:13 +01002156 elif op["op"] == Op.INTDIV:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002157 assert (
2158 pCount == 2 and cCount == 0
Matthew Haddon459443c2021-08-23 16:43:13 +01002159 ), "Op.INTDIV must have 2 placeholders, 0 consts"
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002160
2161 placeholders = []
2162
Matthew Haddon459443c2021-08-23 16:43:13 +01002163 # Two invalid cases for Op.INTDIV:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002164 # 1. divisor == 0
Kevin Cheng47315e12021-05-13 17:41:28 -07002165 # 2. dividend == -(1<<31) and divisor == -1
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002166 while True:
2167 dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
2168 divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])
2169
2170 if (divisor_arr == 0).any():
2171 continue
2172
Kevin Cheng47315e12021-05-13 17:41:28 -07002173 if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002174 continue
2175
2176 break
2177
2178 placeholders.append(
2179 self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
2180 )
2181 placeholders.append(
2182 self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
2183 )
2184
2185 tens.extend(placeholders)
2186 elif op["op"] == Op.MUL:
2187 assert (
2188 pCount == 2 and cCount == 0
2189 ), "Op.MUL must have 2 placeholders, 0 consts"
2190
2191 if dtypeList[0] == DType.FLOAT:
2192 tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
2193 else:
2194 placeholders = []
2195
                # Make sure the multiply result stays within int32 range
2197 shift = testArgs[0]
2198 if dtypeList[0] == DType.INT8:
2199 num_bits = 8
2200 elif dtypeList[0] == DType.INT16:
2201 num_bits = 16
2202 elif dtypeList[0] == DType.INT32:
2203 num_bits = 32
2204 else:
2205 raise Exception("OpMul: invalid input dtype")
2206
2207 for idx, shape in enumerate(shapeList[:]):
2208 low = -(2 ** (num_bits - 1))
2209 high = (2 ** (num_bits - 1)) - 1
2210
2211 a_arr = np.int32(
2212 self.rng.integers(low=low, high=high, size=shapeList[0])
2213 )
2214 b_arr = np.int32(
2215 self.rng.integers(low=low, high=high, size=shapeList[1])
2216 )
2217
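                    # Keep halving both operands until the rounded, shifted
                    # product fits in int32; the magnitudes shrink toward
                    # zero, so the loop terminates.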
2218 i = 0
2219 while True:
2220
2221 a_arr_64 = a_arr.astype(np.int64)
2222 b_arr_64 = b_arr.astype(np.int64)
2223
2224 if shift > 0:
2225 rounding = 1 << (shift - 1)
2226 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
2227 else:
2228 result_arr = a_arr_64 * b_arr_64
2229
2230 if (result_arr > -(2 ** 31)).all() and (
2231 result_arr <= ((2 ** 31) - 1)
2232 ).all():
2233 break
2234
2235 i = i + 1
2236 a_arr = a_arr // 2
2237 b_arr = b_arr // 2
2238
2239 placeholders.append(
2240 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
2241 )
2242 placeholders.append(
2243 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
2244 )
2245
2246 tens.extend(placeholders)
Matthew Haddon818ab902021-07-27 09:12:49 +01002247 elif op["op"] == Op.CONCAT:
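            # Split the inputs so the last num_const_inputs_concat shapes
            # become const tensors, while always keeping at least one
            # placeholder input.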
2248 count = len(shapeList) - self.args.num_const_inputs_concat
2249 if count < 1:
2250 count = 1
2251 if self.args.num_const_inputs_concat == 0:
2252 count = len(shapeList)
2253
2254 shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
2255 tens.extend(
2256 self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
2257 )
2258 tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
Kevin Chengaee1fac2020-11-11 13:54:06 -08002259 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002260 tens.extend(
2261 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
2262 )
2263 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Eric Kunzee5e26762020-10-13 16:11:07 -07002264
2265 if qgen is not None:
Les Bell30e46802021-07-23 09:43:31 +01002266 qinfo = qgen(self, op, dtype_or_dtypeList)
Eric Kunzee5e26762020-10-13 16:11:07 -07002267 else:
2268 qinfo = None
2269
2270 try:
2271 if qinfo is not None:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002272 resultName = build_fcn(self, op["op"], *tens, *testArgs, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07002273 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002274 resultName = build_fcn(self, op["op"], *tens, *testArgs)
Eric Kunzee5e26762020-10-13 16:11:07 -07002275 except TypeError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002276 print(
2277 "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
2278 build_fcn, tens, testArgs
2279 )
2280 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002281 raise e
2282
2283 # Save the serialized test
Kevin Cheng550ccc52021-03-03 11:21:43 -08002284 self.serialize("test")
Eric Kunzee5e26762020-10-13 16:11:07 -07002285
2286 def createDynamicOpLists(self):
2287
2288 # Dynamically create op lists for convolutions with a list of kernel sizes
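        # e.g. "conv2d_TEMPLATE" expands to conv2d_1x1, conv2d_2x2,
        # conv2d_3x3, conv2d_5x5, conv2d_3x1 and conv2d_1x3, one op entry
        # per kernel below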
Kevin Cheng1533b852021-09-01 12:51:58 -07002289 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002290
Kevin Cheng1533b852021-09-01 12:51:58 -07002291 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002292 testName = "conv2d_{}x{}".format(k[0], k[1])
2293 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
2294 self.TOSA_OP_LIST[testName]["filter"] = k
2295 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002296
Kevin Cheng550ccc52021-03-03 11:21:43 -08002297 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
2298 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2299 "depthwise_conv2d_TEMPLATE"
2300 ].copy()
2301 self.TOSA_OP_LIST[testName]["filter"] = k
2302 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002303
Kevin Cheng550ccc52021-03-03 11:21:43 -08002304 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
2305 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2306 "transpose_conv2d_TEMPLATE"
2307 ].copy()
2308 self.TOSA_OP_LIST[testName]["filter"] = k
2309 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002310
Kevin Cheng1533b852021-09-01 12:51:58 -07002311 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
2312 for k in KERNELS_3D:
2313 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
2314 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
2315 self.TOSA_OP_LIST[testName]["filter"] = k
2316 self.TOSA_OP_LIST[testName]["template"] = False
2317
Eric Kunzee5e26762020-10-13 16:11:07 -07002318 # Delete any templates after having created any dynamic ops
2319 # This is a two-pass operation because it's bad practice to delete
2320 # keys from dictionaries while iterating
2321 keyList = []
2322 for k in self.TOSA_OP_LIST:
2323 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002324 if self.TOSA_OP_LIST[k]["template"] == True:
Eric Kunzee5e26762020-10-13 16:11:07 -07002325 keyList.append(k)
2326 continue
2327 except KeyError:
2328 pass
2329
2330 for k in keyList:
2331 del self.TOSA_OP_LIST[k]
2332
2333 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002334 """Fill in default fields for ops if they aren't already specified.
2335 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07002336 for op in self.TOSA_OP_LIST:
2337
2338 # Required fields
2339 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002340 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002341 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002342 raise Exception(
2343 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
2344 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002345
2346 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002347 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002348 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002349 raise Exception(
2350 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
2351 op
2352 )
2353 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002354
2355 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002356 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002357 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002358 raise Exception(
2359 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
2360 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002361
2362 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002363 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002364 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002365 raise Exception(
2366 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
2367 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002368
2369 # Put in default rank range, if missing
2370 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002371 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002372 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002373 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002374
2375 # Tensor operator list
2376 # 'op': op name
2377 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08002378 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
2379 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07002380 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
2381 # 'types': array of datatypes to be tested
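    # 'qgen': optional, QuantInfo generator function (see TosaQuantGen)
    # 'invalid_test_validators': optional, tuple of functions that flag
    # argument combinations to be dropped from the generated test list
    # 'filter': optional, kernel size filled in by createDynamicOpLists()
    # 'template': optional, marks an entry that createDynamicOpLists()
    # expands and then deletes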
Kevin Cheng550ccc52021-03-03 11:21:43 -08002382 TYPE_FP = [DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07002383
Kevin Cheng550ccc52021-03-03 11:21:43 -08002384 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
2385 TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07002386
Kevin Cheng550ccc52021-03-03 11:21:43 -08002387 TYPE_BOOL = [DType.BOOL]
2388 TYPE_FI32 = [DType.FLOAT, DType.INT32]
2389 TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
2390 TYPE_FI16 = [DType.FLOAT, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002391
Kevin Cheng550ccc52021-03-03 11:21:43 -08002392 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07002393
Kevin Cheng1533b852021-09-01 12:51:58 -07002394 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07002395 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002396 [DType.INT8, DType.INT8, DType.INT32],
2397 [DType.INT16, DType.INT8, DType.INT48],
2398 DType.FLOAT,
2399 ]
2400
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002401 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002402
2403 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002404 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002405 "argmax": {
2406 "op": Op.ARGMAX,
2407 "operands": (1, 0),
2408 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2409 "types": TYPE_NARROW_INT_FP,
2410 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002411 "avg_pool2d": {
2412 "op": Op.AVG_POOL2D,
2413 "operands": (1, 0),
2414 "rank": (4, 4),
2415 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
2416 "qgen": TosaQuantGen.qgUnary,
2417 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb724efc2021-08-25 16:40:29 +01002418 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,)
Jared Smolens573ecd42021-03-04 15:24:10 -08002419 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002420 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002421 "conv2d_TEMPLATE": {
2422 "op": Op.CONV2D,
2423 "operands": (1, 2),
2424 "rank": (4, 4),
Les Bell7aa69f42021-09-20 10:44:07 +01002425 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002426 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002427 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01002428 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002429 "template": True,
2430 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002431 # Templated operator. Filled in by createDynamicOpLists
2432 "conv3d_TEMPLATE": {
2433 "op": Op.CONV3D,
2434 "operands": (1, 2),
2435 "rank": (5, 5),
Les Bell7aa69f42021-09-20 10:44:07 +01002436 "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
Kevin Cheng1533b852021-09-01 12:51:58 -07002437 "qgen": TosaQuantGen.qgConv,
2438 "types": TYPE_CONV,
2439 "template": True,
2440 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002441 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002442 "depthwise_conv2d_TEMPLATE": {
2443 "op": Op.DEPTHWISE_CONV2D,
2444 "operands": (1, 2),
2445 "filter": [1, 1],
2446 "rank": (4, 4),
2447 "build_fcn": (
2448 build_depthwise_conv2d,
2449 TosaTensorGen.tgDepthwiseConv2D,
Les Bell7aa69f42021-09-20 10:44:07 +01002450 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002451 ),
2452 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002453 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01002454 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002455 "template": True,
2456 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002457 "fully_connected": {
2458 "op": Op.FULLY_CONNECTED,
2459 "operands": (1, 2),
2460 "rank": (2, 2),
2461 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
2462 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002463 "types": TYPE_CONV,
Jared Smolens573ecd42021-03-04 15:24:10 -08002464 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002465 "matmul": {
2466 "op": Op.MATMUL,
2467 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002468 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08002469 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
2470 "qgen": TosaQuantGen.qgMatmul,
2471 "types": TYPE_NARROW_INT_FP,
2472 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002473 "max_pool2d": {
2474 "op": Op.MAX_POOL2D,
2475 "operands": (1, 0),
2476 "rank": (4, 4),
2477 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
2478 "types": TYPE_NARROW_INT_FP,
Matthew Haddonb724efc2021-08-25 16:40:29 +01002479 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,)
Jared Smolens573ecd42021-03-04 15:24:10 -08002480 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002481 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002482 "transpose_conv2d_TEMPLATE": {
2483 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002484 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002485 "rank": (4, 4),
2486 "build_fcn": (
2487 build_transpose_conv2d,
2488 TosaTensorGen.tgTransposeConv2D,
2489 TosaArgGen.agTransposeConv2D,
2490 ),
2491 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002492 "types": TYPE_CONV,
Matthew Haddonb724efc2021-08-25 16:40:29 +01002493 "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002494 "template": True,
2495 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002496 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002497 "clamp": {
2498 "op": Op.CLAMP,
2499 "operands": (1, 0),
2500 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
2501 "types": TYPE_NARROW_INT_FP,
2502 },
2503 "relun": {
2504 "op": Op.RELUN,
2505 "operands": (1, 0),
2506 "build_fcn": (build_relun, TosaTensorGen.tgBasic, None),
2507 "types": TYPE_FI32,
2508 },
2509 "sigmoid": {
2510 "op": Op.SIGMOID,
2511 "operands": (1, 0),
2512 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
2513 "types": TYPE_FP,
2514 },
2515 "tanh": {
2516 "op": Op.TANH,
2517 "operands": (1, 0),
2518 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
2519 "types": TYPE_FP,
2520 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002521 # Elementwise Binary Operators
2522 "add": {
2523 "op": Op.ADD,
2524 "operands": (2, 0),
2525 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2526 "types": TYPE_FI32,
2527 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002528 "arithmetic_right_shift": {
2529 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2530 "operands": (2, 0),
2531 "build_fcn": (
2532 build_arithmetic_right_shift,
2533 TosaTensorGen.tgBroadcastFuzz,
2534 TosaArgGen.agArithmeticRightShift,
2535 ),
2536 "types": TYPE_INT,
2537 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002538 "bitwise_and": {
2539 "op": Op.BITWISE_AND,
2540 "operands": (2, 0),
2541 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2542 "types": TYPE_INT,
2543 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002544 "bitwise_or": {
2545 "op": Op.BITWISE_OR,
2546 "operands": (2, 0),
2547 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2548 "types": TYPE_INT,
2549 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002550 "bitwise_xor": {
2551 "op": Op.BITWISE_XOR,
2552 "operands": (2, 0),
2553 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2554 "types": TYPE_INT,
2555 },
Matthew Haddon459443c2021-08-23 16:43:13 +01002556 "intdiv": {
2557 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002558 "operands": (2, 0),
2559 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2560 "types": [DType.INT32],
2561 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002562 "logical_and": {
2563 "op": Op.LOGICAL_AND,
2564 "operands": (2, 0),
2565 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2566 "types": TYPE_BOOL,
2567 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002568 "logical_left_shift": {
2569 "op": Op.LOGICAL_LEFT_SHIFT,
2570 "operands": (2, 0),
2571 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2572 "types": TYPE_INT,
2573 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002574 "logical_right_shift": {
2575 "op": Op.LOGICAL_RIGHT_SHIFT,
2576 "operands": (2, 0),
2577 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2578 "types": TYPE_INT,
2579 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002580 "logical_or": {
2581 "op": Op.LOGICAL_OR,
2582 "operands": (2, 0),
2583 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2584 "types": TYPE_BOOL,
2585 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002586 "logical_xor": {
2587 "op": Op.LOGICAL_XOR,
2588 "operands": (2, 0),
2589 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2590 "types": TYPE_BOOL,
2591 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002592 "maximum": {
2593 "op": Op.MAXIMUM,
2594 "operands": (2, 0),
2595 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2596 "types": TYPE_FI32,
2597 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002598 "minimum": {
2599 "op": Op.MINIMUM,
2600 "operands": (2, 0),
2601 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2602 "types": TYPE_FI32,
2603 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002604 "mul": {
2605 "op": Op.MUL,
2606 "operands": (2, 0),
2607 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
2608 "types": TYPE_INT_FP,
2609 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002610 "pow": {
2611 "op": Op.POW,
2612 "operands": (2, 0),
2613 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
2614 "types": TYPE_FP,
2615 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002616 "sub": {
2617 "op": Op.SUB,
2618 "operands": (2, 0),
2619 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2620 "types": TYPE_FI32,
2621 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002622 "table": {
2623 "op": Op.TABLE,
2624 # Use the automatic generation functions to create the input array
2625 # but create the table tensor in the build function, as it may be
2626 # a different type from the input
2627 "operands": (1, 0),
2628 "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002629 "types": [DType.INT8, DType.INT16],
Jared Smolens573ecd42021-03-04 15:24:10 -08002630 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002631 # Elementwise Unary operators
2632 "abs": {
2633 "op": Op.ABS,
2634 "operands": (1, 0),
2635 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2636 "types": TYPE_FI32,
2637 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002638 "bitwise_not": {
2639 "op": Op.BITWISE_NOT,
2640 "operands": (1, 0),
2641 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2642 "types": TYPE_INT,
2643 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002644 "ceil": {
2645 "op": Op.CEIL,
2646 "operands": (1, 0),
2647 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2648 "types": TYPE_FP,
2649 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002650 "clz": {
2651 "op": Op.CLZ,
2652 "operands": (1, 0),
2653 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2654 "types": [DType.INT32],
2655 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002656 "exp": {
2657 "op": Op.EXP,
2658 "operands": (1, 0),
2659 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2660 "types": TYPE_FP,
2661 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002662 "floor": {
2663 "op": Op.FLOOR,
2664 "operands": (1, 0),
2665 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2666 "types": TYPE_FP,
2667 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002668 "log": {
2669 "op": Op.LOG,
2670 "operands": (1, 0),
2671 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2672 "types": TYPE_FP,
2673 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002674 "logical_not": {
2675 "op": Op.LOGICAL_NOT,
2676 "operands": (1, 0),
2677 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2678 "types": TYPE_BOOL,
2679 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002680 "negate": {
2681 "op": Op.NEGATE,
2682 "operands": (1, 0),
2683 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2684 "qgen": TosaQuantGen.qgUnary,
2685 "types": TYPE_INT_FP,
2686 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002687 "reciprocal": {
2688 "op": Op.RECIPROCAL,
2689 "operands": (1, 0),
2690 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2691 "types": TYPE_FP,
2692 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002693 "rsqrt": {
2694 "op": Op.RSQRT,
2695 "operands": (1, 0),
2696 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2697 "types": TYPE_FP,
2698 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002699 # Elementwise Ternary operators
2700 "select": {
2701 "op": Op.SELECT,
2702 "operands": (3, 0),
2703 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
2704 "types": TYPE_FIB,
2705 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002706 # Comparison operators
2707 "equal": {
2708 "op": Op.EQUAL,
2709 "operands": (2, 0),
2710 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2711 "types": TYPE_FI32,
2712 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002713 "greater_equal": {
2714 "op": Op.GREATER_EQUAL,
2715 "operands": (2, 0),
2716 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2717 "types": TYPE_FI32,
2718 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002719 "greater": {
2720 "op": Op.GREATER,
2721 "operands": (2, 0),
2722 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2723 "types": TYPE_FI32,
2724 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002725 # Reduction operators
2726 "reduce_all": {
2727 "op": Op.REDUCE_ALL,
2728 "operands": (1, 0),
2729 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2730 "types": TYPE_BOOL,
2731 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002732 "reduce_any": {
2733 "op": Op.REDUCE_ANY,
2734 "operands": (1, 0),
2735 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2736 "types": TYPE_BOOL,
2737 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002738 "reduce_max": {
2739 "op": Op.REDUCE_MAX,
2740 "operands": (1, 0),
2741 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2742 "types": TYPE_INT_FP,
2743 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002744 "reduce_min": {
2745 "op": Op.REDUCE_MAX,
2746 "operands": (1, 0),
2747 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2748 "types": TYPE_INT_FP,
2749 },
        "reduce_product": {
            "op": Op.REDUCE_PRODUCT,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_FP,
        },
        "reduce_sum": {
            "op": Op.REDUCE_SUM,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_FI32,
        },
        # Data layout operators
        "concat": {
            "op": Op.CONCAT,
            "operands": (2, 0),
            "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
            "types": TYPE_FIB,
        },
        "pad": {
            "op": Op.PAD,
            "operands": (1, 0),
            "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
            "qgen": TosaQuantGen.qgPad,
            "types": TYPE_FIB,
        },
        "reshape": {
            "op": Op.RESHAPE,
            "operands": (1, 0),
            "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
            "types": TYPE_FIB,
        },
        "reverse": {
            "op": Op.REVERSE,
            "operands": (1, 0),
            "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_FIB,
        },
        "slice": {
            "op": Op.SLICE,
            "operands": (1, 0),
            "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
            "types": TYPE_FIB,
        },
        "tile": {
            "op": Op.TILE,
            "operands": (1, 0),
            "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
            "types": TYPE_FIB,
        },
        "transpose": {
            "op": Op.TRANSPOSE,
            "operands": (1, 0),
            "rank": (1, 4),
            "build_fcn": (
                build_transpose,
                TosaTensorGen.tgBasic,
                TosaArgGen.agTranspose,
            ),
            "types": TYPE_FIB,
        },
        # Data nodes
        "const": {
            "op": Op.CONST,
            "operands": (1, 0),
            "build_fcn": (build_placeholder, TosaTensorGen.tgBasic, None),
            "types": TYPE_FIB,
        },
        "identity": {
            "op": Op.IDENTITY,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FIB,
        },
        # Scatter/Gather
        "gather": {
            "op": Op.GATHER,
            # Only specify 'values' tensor here. 'indices' is generated in op building stage
            "operands": (1, 0),
            "rank": (3, 3),
            "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
            "types": TYPE_INT_FP,
        },
        "scatter": {
            "op": Op.SCATTER,
            # Only specify 'values_in' tensor here.
            # 'indices' and 'input' are generated in op building stage
            "operands": (2, 0),
            "rank": (3, 3),
            "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
            "types": TYPE_INT_FP,
        },
        # Image operations
        "resize": {
            "op": Op.RESIZE,
            "operands": (1, 0),
            "rank": (4, 4),
            "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
            "types": [DType.INT8, DType.INT16, DType.FLOAT],
            "invalid_test_validators": (
                TosaInvalidValidator.ivWrongDataTypeOrModeResize,
                TosaInvalidValidator.ivBadStride,
            ),
        },
        # Type conversion
        "cast": {
            "op": Op.CAST,
            "operands": (1, 0),
            "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
            "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
        },
        "rescale": {
            "op": Op.RESCALE,
            "operands": (1, 0),
            "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
            "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
        },
        # Custom
        # Not implemented.
        # Control flow operators
        # Two variants of cond_if: one that generates one of two constant tensors
        # (no inputs to the basic blocks, one output), and another that either adds
        # or subtracts two tensors (two inputs to the basic blocks, one output).
        "cond_if_const": {
            "op": Op.COND_IF,
            "operands": (0, 2),
            "build_fcn": (
                build_cond_if_const,
                TosaTensorGen.tgBasic,
                TosaArgGen.agCondIf,
            ),
            "types": [DType.BOOL],
        },
        "cond_if_binary": {
            "op": Op.COND_IF,
            "operands": (2, 0),
            "build_fcn": (
                build_cond_if_binary,
                TosaTensorGen.tgBasic,
                TosaArgGen.agCondIf,
            ),
            "types": TYPE_FI32,
        },
        # while_loop
        "while_loop": {
            "op": Op.WHILE_LOOP,
            "operands": (0, 1),
            "build_fcn": (
                build_while_loop,
                TosaTensorGen.tgBasic,
                TosaArgGen.agWhileLoop,
            ),
            "types": [DType.INT32],
        },
    }
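
    # A minimal sketch (comment only; 'op_list' and the loop below are
    # hypothetical, not part of this file) of how one entry in the operator
    # table above is typically consumed by a test harness:
    #
    #   entry = op_list["reduce_sum"]
    #   build_fcn, tensor_gen, arg_gen = entry["build_fcn"]
    #   num_placeholders, num_consts = entry["operands"]
    #   for dtype in entry["types"]:
    #       shapes = tensor_gen(...)        # generate input tensor shapes
    #       for args in arg_gen(...):       # enumerate operator attributes
    #           build_fcn(...)              # serialize one test case
    #
    # Optional keys such as "rank", "qgen" and "invalid_test_validators"
    # constrain tensor ranks, attach quantization info, or mark negative tests.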


class OutputShaper:
    # Methods in this class compute the expected output shape and datatype
    # for common classes of operations
    def __init__(self):
        pass

    # These methods return arguments that can be used for
    # creating a new output tensor
    @staticmethod
    def binaryBroadcastOp(ser, a, b):
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        shape = []
        for i in range(len(a.shape)):
            if a.shape[i] == 1:
                shape.append(b.shape[i])
            else:
                shape.append(a.shape[i])

        return ser.addOutput(shape, a.dtype)
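
    # Illustrative example (comment only): with a.shape == [1, 4] and
    # b.shape == [3, 1], the loop above yields shape == [3, 4]; each size-1
    # dimension is stretched to match the other operand.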

    @staticmethod
    def binaryNonBroadcastOp(ser, a, b):
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        shape = []
        for i in range(len(a.shape)):
            assert a.shape[i] == b.shape[i]
            shape.append(a.shape[i])

        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def unaryOp(ser, a):
        return ser.addOutput(a.shape, a.dtype)

    @staticmethod
    def selectOp(ser, cond, a, b):
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
        assert a.dtype == b.dtype

        shape = []
        for i in range(len(a.shape)):
            shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))

        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def binaryComparisonOp(ser, a, b):
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        # Do broadcast
        shape = []
        for i in range(len(a.shape)):
            if a.shape[i] == 1:
                shape.append(b.shape[i])
            else:
                shape.append(a.shape[i])

        # Force the output type to bool
        return ser.addOutput(shape, DType.BOOL)

    @staticmethod
    def reduceOp(ser, a, axis):

        shape = a.shape.copy()

        shape[axis] = 1

        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def argmaxOp(ser, a, axis):
        shape = a.shape.copy()
        del shape[axis]
        return ser.addOutput(shape, DType.INT32)
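
    # Illustrative example (comment only): a.shape == [2, 3, 4] with axis == 1
    # removes that dimension entirely, giving an INT32 output of shape [2, 4].
    # Note the contrast with reduceOp above, which keeps the axis at size 1.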

    @staticmethod
    def conv2dOp(ser, ifm, filter, strides, padding, dilations):

        # IFM: NHWC
        # Filter: OHWI
        # OFM: NHWC

        if len(padding) == 2:
            # Expand padding to 4 parameters in the case of transpose_conv2d
            # From H,W to T,B,L,R
            padding = [padding[0], padding[0], padding[1], padding[1]]

        h = (
            ifm.shape[1]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        w = (
            ifm.shape[2]
            - filter.shape[2]
            - (filter.shape[2] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        return ser.addOutput(ofm_shape, out_dtype)
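
    # Worked example (comment only, values chosen for illustration): with
    # ifm.shape == [1, 16, 16, 8], filter.shape == [4, 3, 3, 8],
    # strides == [1, 1], padding == [1, 1, 1, 1], dilations == [1, 1]:
    #   h = (16 - 3 - 0 + 1 + 1) // 1 + 1 = 16
    # so ofm_shape == [1, 16, 16, 4]. This is the usual formula
    # (H + pad_top + pad_bottom - dilated_kernel) // stride + 1, where
    # dilated_kernel = (kernel - 1) * dilation + 1, just rearranged.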

    @staticmethod
    def conv3dOp(ser, ifm, filter, strides, padding, dilations):

        # IFM: NDHWC
        # Filter: ODHWI
        # OFM: NDHWC

        d = (
            ifm.shape[1]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        h = (
            ifm.shape[2]
            - filter.shape[2]
            - (filter.shape[2] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        w = (
            ifm.shape[3]
            - filter.shape[3]
            - (filter.shape[3] - 1) * (dilations[2] - 1)
            + padding[4]
            + padding[5]
        ) // strides[2] + 1

        ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        return ser.addOutput(ofm_shape, out_dtype)

    @staticmethod
    def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
        # IFM: NHWC
        # Filter: HWCM
        # OFM: NHW C*M
        h = (
            ifm.shape[1]
            - filter.shape[0]
            - (filter.shape[0] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        w = (
            ifm.shape[2]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        return ser.addOutput(ofm_shape, out_dtype)
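
    # Illustrative example (comment only): an HWCM filter of shape [3, 3, 8, 2]
    # (channel multiplier M == 2) over an NHWC input with C == 8 produces
    # C * M == 16 output channels; the spatial math matches conv2dOp above.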

    @staticmethod
    def pool2dOp(ser, ifm, kernel, stride, pad):
        # input: NHWC
        h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
        w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]

        ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
        return ser.addOutput(ofm_shape, ifm.dtype)
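
    # Illustrative example (comment only): ifm.shape == [1, 8, 8, 4],
    # kernel == [2, 2], stride == [2, 2], pad == [0, 0, 0, 0] gives
    #   h = (8 + 0 + 0 + 2 - 2) // 2 = 4
    # i.e. the familiar (H + pads - kernel) // stride + 1, with the "+ 1"
    # folded in as an extra "+ stride" inside the division.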

    @staticmethod
    def fullyConnectedOp(ser, input, filter):
        # input: N, IC
        # filter: OC, IC
        # output: N, OC

        output_shape = [input.shape[0], filter.shape[0]]

        if input.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif input.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif input.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(input.dtype))

        return ser.addOutput(output_shape, out_dtype)

    @staticmethod
    def matmulOp(ser, a, b):
        # a: N, H, C
        # b: N, C, W
        # out: N, H, W

        output_shape = [a.shape[0], a.shape[1], b.shape[2]]

        if a.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif a.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif a.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))

        return ser.addOutput(output_shape, out_dtype)
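
    # Illustrative example (comment only): a.shape == [2, 3, 4] (N, H, C) and
    # b.shape == [2, 4, 5] (N, C, W) give output_shape == [2, 3, 5]; an INT8
    # input pair widens to an INT32 accumulator output.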

    @staticmethod
    def concatOp(ser, axis, *a):
        input1 = a[0]
        remaining_inputs = a[1:]

        output_shape = input1.shape.copy()

        output_shape[axis] = input1.shape[axis]

        for tensor in remaining_inputs:
            output_shape[axis] += tensor.shape[axis]

        return ser.addOutput(output_shape, input1.dtype)
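
    # Illustrative example (comment only): concatenating shapes [2, 3, 4] and
    # [2, 5, 4] along axis == 1 gives output_shape == [2, 8, 4]; only the
    # concatenation axis grows, all other dimensions must already agree.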

    @staticmethod
    def padOp(ser, a, padding):

        output_shape = a.shape.copy()

        for i in range(len(output_shape)):
            output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]

        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def reshapeOp(ser, a, shape):
        output_shape = shape.copy()

        totalElements = 1
        for i in a.shape:
            totalElements *= i

        # If there are any -1 elements, figure out what that dimension must be
        totalOutputElements = 1
        for i in output_shape:
            if i != -1:
                totalOutputElements *= i

        # And fill it in
        for i in range(len(output_shape)):
            if output_shape[i] == -1:
                output_shape[i] = totalElements // totalOutputElements

        return ser.addOutput(output_shape, a.dtype)
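
    # Illustrative example (comment only): a.shape == [2, 3, 4] (24 elements)
    # reshaped to [4, -1] infers the -1 dimension as 24 // 4 == 6, so the
    # output shape is [4, 6].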

    @staticmethod
    def sliceOp(ser, a, begin, size):

        output_shape = size.copy()
        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def tileOp(ser, a, multiples):

        output_shape = a.shape.copy()
        assert len(multiples) == len(output_shape)

        for i in range(len(output_shape)):
            output_shape[i] = a.shape[i] * multiples[i]

        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def transposeOp(ser, a, perms):
        output_shape = a.shape.copy()
        assert len(perms) == len(output_shape)

        for i in range(len(output_shape)):
            output_shape[i] = a.shape[perms[i]]

        return ser.addOutput(output_shape, a.dtype)
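
    # Illustrative example (comment only): a.shape == [1, 2, 3] with
    # perms == [2, 0, 1] gives output_shape == [3, 1, 2], since output
    # dimension i takes its size from input dimension perms[i].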

    @staticmethod
    def gatherOp(ser, values, indices):
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

        output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

        return ser.addOutput(output_shape, values.dtype)
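
    # Illustrative example (comment only): values.shape == [2, 7, 3] (N, K, C)
    # and indices.shape == [2, 5] (N, W) give an output of shape [2, 5, 3],
    # one gathered row of C values per index.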

    @staticmethod
    def scatterOp(ser, values_in, indices, input):
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

        output_shape = values_in.shape

        return ser.addOutput(output_shape, values_in.dtype)

    @staticmethod
    def tableOp(ser, input, table_dtype):
        # Same shape as the input, but dtype dependent on table dtype
        assert table_dtype == DType.INT16 or table_dtype == DType.INT8
        output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
        return ser.addOutput(input.shape, output_dtype)

    @staticmethod
    def resizeOp(
        ser,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
    ):

        output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]

        return ser.addOutput(output_dims, output_dtype)

    @staticmethod
    def typeConversionOp(ser, val, out_dtype):
        return ser.addOutput(val.shape, out_dtype)

    @staticmethod
    def transposeConv2DOp(ser, ifm, output_shape):
        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        return ser.addOutput(output_shape, out_dtype)