blob: e375a2ab1e5710755969837c1370fdbd310faba8 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18import numpy as np
19import argparse
20import sys
21import re
22import os
23import subprocess
24import shlex
25import json
26import glob
27import math
28import queue
29import threading
30import traceback
31import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010032import itertools
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
35
Kevin Cheng550ccc52021-03-03 11:21:43 -080036# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
37parent_dir = os.path.dirname(os.path.realpath(__file__))
38sys.path.append(
39 os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
40)
Eric Kunzee5e26762020-10-13 16:11:07 -070041import tosa_serializer as ts
42from tosa_serializer import *
43import tosa
44
45# Convenience variables to the flatc-generated types that should be enums, but aren't
46DType = tosa.DType.DType()
Kevin Cheng550ccc52021-03-03 11:21:43 -080047Op = tosa.Op.Op()
Eric Kunzee5e26762020-10-13 16:11:07 -070048ResizeMode = tosa.ResizeMode.ResizeMode()
49
Kevin Cheng550ccc52021-03-03 11:21:43 -080050
class TosaQuantGen:
    """Helpers that build randomized QuantizedInfo records.

    An operator definition selects one of these generators through its
    'qgen' entry.  Each qg* helper returns a populated
    ts.TosaSerializerQuantInfo for the relevant operator class.
    """

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype):
        # Pick a random zero point that is legal for the data type.
        # Types that are not asymmetrically quantized always use 0.
        if dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype):
        # Unary ops carry an input and an output zero point.
        qinfo = ts.TosaSerializerQuantInfo()
        input_zp = TosaQuantGen.getQinfo(testGen, dtype)
        output_zp = TosaQuantGen.getQinfo(testGen, dtype)
        qinfo.UnaryQuantInfo(input_zp, output_zp)
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList):
        # Convolution ops carry input and weight zero points.  The dtype
        # argument is either a single dtype or a [input, weights,
        # accumulator] list.
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # explicit [input, weights, accumulator] dtypes
            dtypes = dtype_or_dtypeList
        else:
            # a single dtype shared by input, weights and accumulator
            dtypes = [dtype_or_dtypeList] * 3
        in_zp = TosaQuantGen.getQinfo(testGen, dtypes[0])
        w_zp = TosaQuantGen.getQinfo(testGen, dtypes[1])
        qinfo.ConvQuantInfo(in_zp, w_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype):
        # MatMul carries a zero point for each of its two inputs.
        qinfo = ts.TosaSerializerQuantInfo()
        a_zp = TosaQuantGen.getQinfo(testGen, dtype)
        b_zp = TosaQuantGen.getQinfo(testGen, dtype)
        qinfo.MatMulQuantInfo(a_zp, b_zp)
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype):
        # Pad carries a single input zero point.
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into a fixed-point (multiplier, shift).

        Derived from computeMultiplierAndShiftTosaScale32: scale32 selects a
        31-bit mantissa, otherwise a 15-bit mantissa is used.
        """
        scaleBits = 31 if scale32 else 15

        mantissa, exponent = math.frexp(scaleFp)
        # frexp keeps the sign in the mantissa; fold it back out so the
        # rounding below behaves identically for negative scales.
        if scaleFp < 0.0:
            mantissa = -mantissa

        one = 1 << scaleBits
        multiplier = round(mantissa * one)
        assert multiplier <= one

        # Rounding can push the mantissa to exactly 1.0; renormalize.
        if multiplier == one:
            multiplier //= 2
            exponent += 1

        shift = scaleBits - exponent
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(scaleFp, scaleBits, mantissa, multiplier, shift))

        assert multiplier <= one
        assert 0 <= shift <= 63

        return multiplier, shift
129
130
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.  The actual random data is generated
    separately for each test.

    Each tg* generator takes (testGen, op, rank), where op is the operator
    definition dictionary, and returns one shape per operand.
    """

    def __init__(self):
        pass

    @staticmethod
    def _clampBatch(testGen, shape):
        # Constrain the batch (dimension 0) to the command-line maximum, if set.
        # Mutates and returns the shape for call-site convenience.
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1
        return shape

    @staticmethod
    def tgBasic(testGen, opName, rank):
        """All (pl + const) operands share a single random shape of the given rank."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        return [shape.copy() for _ in range(pl + const)]

    @staticmethod
    def tgNHWC(testGen, opName, rank):
        """Rank-4 NHWC shapes shared by all operands; batch may be clamped."""
        pl, const = opName["operands"]

        assert rank == 4

        shape = TosaTensorGen._clampBatch(testGen, testGen.makeShape(rank))

        return [shape.copy() for _ in range(pl + const)]

    @staticmethod
    def tgScatter(testGen, opName, rank):
        """Shapes for scatter/gather style ops: a rank-3 values_in tensor plus a
        second input sharing its batch and channel dims, with a random W."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = TosaTensorGen._clampBatch(testGen, testGen.makeShape(rank))

        # W is drawn from the configured tensor shape range
        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        return [values_in_shape.copy(), input_shape.copy()]

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank):
        """All operands share one shape, except one randomly-chosen operand
        which gets a randomly-chosen dimension set to 1 (to be broadcast)."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If this is the chosen input, pick a random dimension to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def _tgConvCommon(testGen, op, rank):
        # Shared by tgConv2D and tgTransposeConv2D, which historically held
        # two identical copies of this code.  RNG call order (makeShape(rank)
        # then makeShape(1)) is preserved from the original implementations.
        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = TosaTensorGen._clampBatch(testGen, testGen.makeShape(rank))

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv2D(testGen, op, rank):
        """Shapes for CONV2D: NHWC input, OHWI filter, OC bias."""
        return TosaTensorGen._tgConvCommon(testGen, op, rank)

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank):
        """Shapes for TRANSPOSE_CONV2D: NHWC input, OHWI filter, OC bias."""
        return TosaTensorGen._tgConvCommon(testGen, op, rank)

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank):
        """Shapes for DEPTHWISE_CONV2D: NHWC input, HWCM filter, M*C bias."""
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = TosaTensorGen._clampBatch(testGen, testGen.makeShape(rank))

        # Get the filter height/width from the operator parameters.
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random depth multiplier M, but don't let it get too big
        # because the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank):
        """Shapes for FULLY_CONNECTED: (N, IC) input, (OC, IC) filter, OC bias."""
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        # Random output-channel count from the configured shape range
        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank):
        """Shapes for MATMUL: (N, H, C) x (N, C, W) with a random W."""
        pl, const = op["operands"]

        assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        b_oc = testGen.makeShape(1)[0]
        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])

        return [a_shape, b_shape]
332
Kevin Cheng550ccc52021-03-03 11:21:43 -0800333
class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for
    operators that take attributes or other parameters.  The return value is a
    list of (descriptive_name, [arglist]) tuples where the descriptive_name is
    appended to the test name and the arglist is expanded as arguments to the
    operator build function.

    Fix over the original: agCondIf and agWhileLoop were the only generators
    missing @staticmethod; they only worked because plain functions accessed
    via the class behave as functions in Python 3.  They are now decorated
    consistently with the other ag* generators.
    """

    def __init__(self):
        pass

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype):
        """Build the axis argument for operators that take a single axis:
        one entry per dimension of the first input shape."""
        axes = []

        shape = shapeList[0]

        for a in range(0, len(shape)):
            axes.append(("axis{}".format(a), [a]))
        return axes

    @staticmethod
    def agConv2D(testGen, opName, shapeList, dtype):
        """Enumerate stride/padding/dilation combinations for CONV2D,
        bounded by the command-line max_conv_* arguments."""
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        maxStride = testGen.args.max_conv_stride
        maxPadding = testGen.args.max_conv_padding + 1
        maxDilation = testGen.args.max_conv_dilation

        # Strides, padding, dilations
        for stride in range(0, maxStride ** 2):
            for padding in range(0, (maxPadding) ** 4):
                for dilation in range(0, maxDilation ** 2):

                    # Decode the flat loop counters into per-axis values.
                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    # NOTE(review): these divisors (maxPadding * 4/2/1) only
                    # enumerate all 4-tuples uniquely when maxPadding == 2
                    # (the default); for larger values maxPadding ** 3/2/1
                    # would be needed.  Left as-is to preserve generated
                    # test coverage -- confirm before changing defaults.
                    p = [
                        (padding // (maxPadding * 4)) % maxPadding,
                        (padding // (maxPadding * 2)) % maxPadding,
                        (padding // (maxPadding * 1)) % maxPadding,
                        padding % maxPadding,
                    ]
                    d = [dilation // maxDilation + 1, dilation % maxDilation + 1]

                    # 4 padding parameters for regular conv2d
                    arg_list.append(
                        (
                            "st{}{}_pad{}{}{}{}_dilat{}{}".format(
                                s[0], s[1], p[0], p[1], p[2], p[3], d[0], d[1]
                            ),
                            [s, p, d],
                        )
                    )
        return arg_list

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtype):
        """Enumerate stride/out_padding/dilation combinations for
        TRANSPOSE_CONV2D and compute the resulting output shape."""
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        maxStride = testGen.args.max_conv_stride
        maxPadding = testGen.args.max_conv_padding + 1
        maxDilation = testGen.args.max_conv_dilation

        # Strides, padding, dilations
        for stride in range(0, maxStride ** 2):
            for out_padding in range(0, (maxPadding) ** 2):
                for dilation in range(0, maxDilation ** 2):

                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    # 2 output-padding parameters (height, width)
                    p = [
                        (out_padding // (maxPadding * 1)) % maxPadding,
                        out_padding % maxPadding,
                    ]
                    d = [dilation // maxDilation + 1, dilation % maxDilation + 1]

                    # Output height/width implied by the chosen parameters
                    oh = (
                        ifm_shape[1]
                        - filter_shape[1]
                        - (filter_shape[1] - 1) * (d[0] - 1)
                        + 2 * p[0]
                    ) // s[0] + 1

                    ow = (
                        ifm_shape[2]
                        - filter_shape[2]
                        - (filter_shape[2] - 1) * (d[1] - 1)
                        + 2 * p[1]
                    ) // s[1] + 1

                    # Output shape
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]

                    arg_list.append(
                        (
                            "st{}{}_outpad{}{}_dilat{}{}_os{}x{}x{}x{}".format(
                                s[0],
                                s[1],
                                p[0],
                                p[1],
                                d[0],
                                d[1],
                                os[0],
                                os[1],
                                os[2],
                                os[3],
                            ),
                            [s, p, d, os],
                        )
                    )

        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype):
        """Exhaustively test combinations of padding on each side of each
        dimension.

        - the range of padding values is defined by pad_min and pad_max
        - for padding >9, the name format needs to be more distinctive
        """
        arg_list = []
        rank = len(shapeList[0])

        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        # (before, after) pairs for one axis, then the cross-product over all axes
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        for paddings in shape_pad_values:
            name = "pad"
            for r in range(rank):
                before, after = paddings[r]
                name = f"{name}{before}{after}"
            arg_list.append((name, [np.array(paddings)]))

        return arg_list

    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype):
        """Enumerate kernel/stride/padding combinations for pooling ops,
        bounded by the command-line max_pooling_* arguments."""
        arg_list = []

        shape = shapeList[0]
        assert len(shape) == 4

        maxStride = testGen.args.max_pooling_stride
        maxKernel = testGen.args.max_pooling_kernel
        maxPadding = testGen.args.max_pooling_padding + 1

        for kernel in range(0, maxKernel ** 2):
            for stride in range(0, maxStride ** 2):
                for padding in range(0, maxPadding ** 4):
                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    # Kernel sizes start at 2
                    k = [(kernel // maxKernel) + 2, (kernel % maxKernel) + 2]
                    # NOTE(review): same padding-decode caveat as agConv2D --
                    # only exhaustive/unique for maxPadding == 2.
                    p = [
                        (padding // (maxPadding * 4)) % maxPadding,
                        (padding // (maxPadding * 2)) % maxPadding,
                        (padding // (maxPadding * 1)) % maxPadding,
                        padding % maxPadding,
                    ]

                    arg_list.append(
                        (
                            "st{}{}_kern{}{}_pad{}{}{}{}".format(
                                s[0], s[1], k[0], k[1], p[0], p[1], p[2], p[3]
                            ),
                            [k, s, p],
                        )
                    )
        return arg_list

    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype):
        """Enumerate the legal output types for CAST from the given input type."""
        arg_list = []

        # Enumerate the output types here
        if inDtype == DType.INT8:
            dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT16:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT32:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FLOAT:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))

        return arg_list

    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype):
        """Enumerate legal output type / scale32 / double_round / per_channel
        combinations for RESCALE, skipping spec-illegal ones."""
        arg_list = []

        # Enumerate the output types here
        for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
            if inDtype == DType.UINT8 and dtype != DType.INT8:
                # The only output dtype for UINT8 is INT8, skip all other combinations
                continue
            if inDtype != DType.INT8 and dtype == DType.UINT8:
                # The only input dtype for UINT8 is INT8, skip all other combinations
                continue

            for scale32 in [False, True]:
                for double_round in [False, True]:
                    for per_channel in [False, True]:

                        if inDtype == DType.INT48 and scale32:
                            # Illegal condition.  Must be scale32=False
                            continue
                        if double_round and not scale32:
                            # Illegal condition.  ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[dtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [dtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

    @staticmethod
    def agMul(testGen, opName, shapeList, dtype):
        """Random shift amounts for MUL on INT32; shift is 0 otherwise."""
        arg_list = []

        if dtype is DType.INT32:
            for p in range(testGen.args.num_rand_permutations):

                shift = testGen.randInt(0, 32)

                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("perm0_shift0", [0]))

        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype):
        """Both rounding modes for ARITHMETIC_RIGHT_SHIFT."""
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    # Helper function for reshape.  Gets some factors of a larger number.
    # Note: only returns factors up to sqrt(val), not their co-factors.
    @staticmethod
    def getFactors(val, start=1):
        factors = []

        for i in range(start, int(np.sqrt(val)) + 1):
            if (val % i) == 0:
                factors.append(i)

        return factors

    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype):
        """Generate random new shapes with the same element count as the
        input, occasionally substituting a -1 wildcard dimension."""
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast.  Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)

                # Toss in a -1 sometimes
                minusOne = testGen.randInt(0, newRank * 4)
                if minusOne < newRank:
                    newShape[minusOne] = -1

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list

    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype):
        """Random sample (without replacement) of the axis permutations of
        the input shape, capped by num_rand_permutations."""
        arg_list = []

        ifm_shape = shapeList[0]

        # Get all permutations
        permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Get random permutation generator that uses all permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create list of required amount of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list

    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype):
        """Random begin/size pairs for SLICE; combinations yielding a
        zero-sized slice are discarded."""
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            begin = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    begin.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    begin.append(0)
                    size.append(1)

            if valid:
                arg_list.append(("perm{}".format(p), [begin, size]))
        return arg_list

    @staticmethod
    def agTile(testGen, opName, shapeList, dtype):
        """Random small multiples for TILE (kept small to bound tensor size)."""
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random, but small multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                multiples.append(testGen.randInt(1, 4))

            arg_list.append(("perm{}".format(p), [multiples]))

        return arg_list

    @staticmethod
    def agResize(testGen, opName, shapeList, dtype):
        """Generate RESIZE arguments: for each legal {mode, output type}
        pair, pick random output dims and derive stride/offset (fixed-point
        for integer types, floating-point for FLOAT)."""
        arg_list = []

        ifm_shape = shapeList[0]

        for m in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations.  Pick legal output types
            if m == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif m == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):

                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    output_dims = [testGen.randInt(1), testGen.randInt(1)]
                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        # FLOAT mode uses the fp stride/offset directly;
                        # the fixed-point fields are zeroed.
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]
                        arg_list.append(
                            (
                                "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                    "N" if m == ResizeMode.NEAREST else "B",
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride_fp[0],
                                    stride_fp[1],
                                    offset_fp[0],
                                    offset_fp[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )
                    else:
                        # Integer modes quantize stride/offset with a shift;
                        # reduce the shift until everything fits in int16.
                        shift = 11
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))

                        while (
                            stride_y >= 32768
                            or stride_x >= 32768
                            or offset_y >= 32768
                            or offset_x >= 32768
                            or offset_y < -32768
                            or offset_x < -32768
                        ):
                            shift = shift - 1
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                        arg_list.append(
                            (
                                "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                    "N" if m == ResizeMode.NEAREST else "B",
                                    shift,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride[0],
                                    stride[1],
                                    offset[0],
                                    offset[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )

        return arg_list

    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype):
        # While loop: 0 iterations, 1, more than 1
        arg_list = []

        for iter in [0, 1, 4]:
            arg_list.append(("iter{}".format(iter), [iter]))

        return arg_list
885
Kevin Cheng550ccc52021-03-03 11:21:43 -0800886
Eric Kunzee5e26762020-10-13 16:11:07 -0700887class TosaTestGen:
    # Maximum rank of tensor supported by the test generator; rank filters
    # passed to genOpTestList are bounded by this value.
    TOSA_TENSOR_MAX_RANK = 6
890
    def __init__(self, args):
        """Initialize the generator from parsed command-line arguments."""
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        # Serializer is created per-test by createSerializer()
        self.ser = None
        self.rng = np.random.default_rng(self.random_seed)
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None
902
    def createSerializer(self, opName, testPath):
        """Create a fresh serializer writing into <basePath>/<opName>/<testPath>.

        The output directory is created if necessary.
        """
        self.testPath = os.path.join(opName, testPath)

        fullPath = os.path.join(self.basePath, self.testPath)
        os.makedirs(fullPath, exist_ok=True)
        self.ser = ts.TosaSerializer(fullPath)
909
    def getSerializer(self):
        """Return the current serializer (None until createSerializer is called)."""
        return self.ser
912
913 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800914 with open(
915 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
916 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700917 fd.write(self.ser.serialize())
918
Kevin Cheng550ccc52021-03-03 11:21:43 -0800919 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
920 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -0700921
922 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -0700923 if dtype == DType.BOOL:
924 np_dt = np.bool
925 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700926 elif dtype == DType.INT4:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100927 return np.int32(self.rng.integers(low=-8, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700928 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100929 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
930 elif dtype == DType.UINT8:
931 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700932 elif dtype == DType.INT16:
933 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
934 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800935 return np.int32(
936 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
937 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700938 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800939 return np.int64(
940 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
941 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700942 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100943 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700944 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800945 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700946
Kevin Cheng989cb052021-04-28 16:29:44 -0700947 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700948 placeholders = []
949
Kevin Cheng989cb052021-04-28 16:29:44 -0700950 assert len(shape_list) == len(dtype_list)
951
952 for idx, shape in enumerate(shape_list):
953 arr = self.getRandTensor(shape, dtype_list[idx])
954 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700955
956 return placeholders
957
Kevin Cheng989cb052021-04-28 16:29:44 -0700958 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700959 consts = []
960
Kevin Cheng989cb052021-04-28 16:29:44 -0700961 assert len(shape_list) == len(dtype_list)
962
963 for idx, shape in enumerate(shape_list):
964 arr = self.getRandTensor(shape, dtype_list[idx])
965 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700966
967 return consts
968
969 def makeShape(self, rank):
970 if self.targetted_shape:
971 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800972 return np.int32(
973 self.rng.integers(
974 low=self.args.tensor_shape_range[0],
975 high=self.args.tensor_shape_range[1],
976 size=rank,
977 )
978 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700979
    def setTargetShape(self, shape):
        # Force makeShape() to return this exact shape; None restores random shapes
        self.targetted_shape = shape
982
    def randInt(self, low=0, high=256):
        """Return a single random np.int32 in the half-open range [low, high)."""
        return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
985
    def getRandNumberDType(self, dtype):
        """Return a single random scalar value of the given DType.

        FLOAT returns a python float in [0, 1); BOOL a python bool; the
        integer types return an np.int32 (np.int64 for INT48) within the
        type's full range.
        """
        if dtype == DType.FLOAT:
            return self.rng.random()
        elif dtype == DType.BOOL:
            return self.rng.choice([False, True])
        elif dtype == DType.INT4:
            low, high = (-8, 8)
        elif dtype == DType.INT8:
            low, high = (-128, 128)
        elif dtype == DType.INT16:
            low, high = (-32768, 32768)
        elif dtype == DType.INT32:
            low, high = (-(1 << 31), (1 << 31))
        elif dtype == DType.INT48:
            low, high = (-(1 << 47), (1 << 47))
            # Special size: INT48 exceeds int32 range, so return 64-bit here
            return np.int64(self.rng.integers(low, high, size=1))[0]
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

        return np.int32(self.rng.integers(low, high, size=1))[0]
1007
1008 def shapeStr(self, shape):
1009
1010 sStr = []
1011 # Convert to strings
1012 for i in shape:
1013 sStr.append(str(i))
1014
Kevin Cheng550ccc52021-03-03 11:21:43 -08001015 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001016
1017 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -07001018 if isinstance(t, list):
1019 assert len(t) >= 2
1020 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001021 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07001022 if t == DType.BOOL:
1023 return "b"
1024 elif t == DType.INT4:
1025 return "i4"
1026 elif t == DType.INT8:
1027 return "i8"
1028 elif t == DType.UINT8:
1029 return "u8"
1030 elif t == DType.INT16:
1031 return "i16"
1032 elif t == DType.INT32:
1033 return "i32"
1034 elif t == DType.INT48:
1035 return "i48"
1036 elif t == DType.FLOAT:
1037 return "float"
1038 else:
1039 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001040
1041 def typeWidth(self, t):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001042 """ Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -08001043 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07001044 return 4
1045 elif t == DType.INT8:
1046 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08001047 elif t == DType.UINT8:
1048 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07001049 elif t == DType.INT16:
1050 return 16
1051 elif t == DType.INT32:
1052 return 32
1053 elif t == DType.INT48:
1054 return 48
1055 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001056 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001057
1058 # Argument generators
1059 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
1060 # Where the string descriptor is used to generate the test name and
1061 # The build_fcn_arg_list is expanded and passed to the operator test
1062 # build function
1063
Kevin Cheng550ccc52021-03-03 11:21:43 -08001064 def build_unary(self, op, a, qinfo=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001065 result_tens = OutputShaper.unaryOp(self.ser, a)
1066 self.ser.addOperator(op, [a.name], [result_tens.name], None, qinfo)
1067 return result_tens
1068
1069 def build_binary_broadcast(self, op, a, b):
1070 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1071 self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
1072 return result_tens
1073
1074 def build_binary_nonbroadcast(self, op, a, b):
1075 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
1076 self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
1077 return result_tens
1078
Kevin Chengaee1fac2020-11-11 13:54:06 -08001079 def build_arithmetic_right_shift(self, op, a, b, round):
1080 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1081
1082 attr = ts.TosaSerializerAttribute()
1083 attr.ArithmeticRightShiftAttribute(round)
1084
1085 self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
1086 return result_tens
1087
1088 def build_mul(self, op, a, b, shift):
Eric Kunzee5e26762020-10-13 16:11:07 -07001089 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1090
1091 # Special for multiply:
1092 # Force the result to INT32 for INT types
1093 if a.dtype != DType.FLOAT:
1094 result_tens.setDtype(DType.INT32)
1095
Kevin Chengaee1fac2020-11-11 13:54:06 -08001096 attr = ts.TosaSerializerAttribute()
1097 attr.MulAttribute(shift)
1098
1099 self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001100 return result_tens
1101
1102 def build_table(self, op, a):
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01001103 # Constant size depending on type, random values
1104 if a.dtype == DType.INT16:
1105 table_dtype = DType.INT16
1106 table_arr = self.getRandTensor([513], table_dtype)
1107 else:
1108 assert a.dtype == DType.INT8
1109 table_dtype = DType.INT8
1110 table_arr = self.getRandTensor([256], table_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001111
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01001112 table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
1113 result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001114 self.ser.addOperator(op, [a.name, table_tens.name], [result_tens.name], None)
1115
1116 return result_tens
1117
1118 def build_select(self, op, cond, a, b):
Eric Kunzee5e26762020-10-13 16:11:07 -07001119 result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
1120 self.ser.addOperator(op, [cond.name, a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001121 return result_tens
1122
1123 def build_comparison(self, op, a, b):
1124 result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
1125 self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
1126 return result_tens
1127
1128 def build_argmax(self, op, a, axis):
1129 result_tens = OutputShaper.argmaxOp(self.ser, a, axis)
1130
1131 attr = ts.TosaSerializerAttribute()
1132 attr.AxisAttribute(axis)
1133
1134 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1135 return result_tens
1136
Kevin Cheng550ccc52021-03-03 11:21:43 -08001137 def build_pool2d(self, op, input, kernel, stride, pad, qinfo=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001138 result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)
1139
1140 attr = ts.TosaSerializerAttribute()
1141 attr.Pool2dAttribute(kernel, stride, pad)
Eric Kunzee5e26762020-10-13 16:11:07 -07001142
1143 self.ser.addOperator(op, [input.name], [result_tens.name], attr, qinfo)
1144 return result_tens
1145
1146 def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001147 assert len(padding) == 4
1148 result_tens = OutputShaper.conv2dOp(
1149 self.ser, ifm, filter, strides, padding, dilations
1150 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001151
1152 attr = ts.TosaSerializerAttribute()
1153 attr.Conv2dAttribute(padding, strides, dilations)
1154
Kevin Cheng550ccc52021-03-03 11:21:43 -08001155 self.ser.addOperator(
1156 op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
1157 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001158 return result_tens
1159
Kevin Cheng550ccc52021-03-03 11:21:43 -08001160 def build_transpose_conv2d(
Kevin Cheng989cb052021-04-28 16:29:44 -07001161 self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08001162 ):
1163 assert len(outpad) == 2
Eric Kunzee5e26762020-10-13 16:11:07 -07001164 result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
1165
1166 attr = ts.TosaSerializerAttribute()
1167 attr.TransposeConv2DAttribute(outpad, stride, dilation, output_shape)
1168
Kevin Cheng550ccc52021-03-03 11:21:43 -08001169 self.ser.addOperator(
Kevin Cheng989cb052021-04-28 16:29:44 -07001170 op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08001171 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001172 return result_tens
1173
Kevin Cheng550ccc52021-03-03 11:21:43 -08001174 def build_depthwise_conv2d(
1175 self, op, ifm, filter, bias, strides, padding, dilations, qinfo
1176 ):
1177 result_tens = OutputShaper.depthwiseConv2dOp(
1178 self.ser, ifm, filter, strides, padding, dilations
1179 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001180
1181 attr = ts.TosaSerializerAttribute()
1182 attr.Conv2dAttribute(padding, strides, dilations)
1183
Kevin Cheng550ccc52021-03-03 11:21:43 -08001184 self.ser.addOperator(
1185 op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
1186 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001187 return result_tens
1188
1189 def build_fully_connected(self, op, ifm, filter, bias, qinfo):
1190 result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)
1191
Kevin Cheng550ccc52021-03-03 11:21:43 -08001192 self.ser.addOperator(
1193 op, [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
1194 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001195 return result_tens
1196
1197 def build_matmul(self, op, a, b, qinfo):
1198 result_tens = OutputShaper.matmulOp(self.ser, a, b)
1199 self.ser.addOperator(op, [a.name, b.name], [result_tens.name], None, qinfo)
1200 return result_tens
1201
1202 def build_reduce(self, op, a, axis):
1203 result_tens = OutputShaper.reduceOp(self.ser, a, axis)
1204
1205 attr = ts.TosaSerializerAttribute()
1206 attr.AxisAttribute(axis)
1207
1208 self.ser.addOperator(op, [a.name], result_tens.name, attr)
1209 return result_tens
1210
1211 def build_clamp(self, op, a):
1212 result_tens = OutputShaper.unaryOp(self.ser, a)
1213
1214 attr = ts.TosaSerializerAttribute()
Jeremy Johnson18e26662021-07-22 16:15:29 +01001215 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07001216
1217 if a.dtype == DType.FLOAT:
1218 attr.ClampAttribute(0, 0, min(v), max(v))
1219 else:
1220 attr.ClampAttribute(min(v), max(v), 0, 0)
1221
1222 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1223 return result_tens
1224
1225 def build_leaky_relu(self, op, a):
1226 result_tens = OutputShaper.unaryOp(self.ser, a)
1227 attr = ts.TosaSerializerAttribute()
1228
1229 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
1230
1231 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1232 return result_tens
1233
1234 # Needs an additional type/input
1235 def build_prelu(self, op, a):
1236 result_tens = OutputShaper.unaryOp(self.ser, a)
1237
1238 self.ser.addOperator(op, [a.name], [result_tens.name])
1239 return result_tens
1240
1241 def build_relun(self, op, a):
1242 result_tens = OutputShaper.unaryOp(self.ser, a)
1243
1244 attr = ts.TosaSerializerAttribute()
1245
1246 if a.dtype == DType.FLOAT:
1247 attr.ReluNAttribute(0, self.getRandNumberDType(a.dtype))
1248 else:
1249 attr.ReluNAttribute(self.getRandNumberDType(a.dtype), 0)
1250
1251 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1252 return result_tens
1253
1254 def build_sigmoid(self, op, a):
1255 result_tens = OutputShaper.unaryOp(self.ser, a)
1256 self.ser.addOperator(op, [a.name], [result_tens.name])
1257 return result_tens
1258
1259 def build_tanh(self, op, a):
1260 result_tens = OutputShaper.unaryOp(self.ser, a)
1261 self.ser.addOperator(op, [a.name], [result_tens.name])
1262 return result_tens
1263
1264 def build_concat(self, op, a, b, axis):
1265 result_tens = OutputShaper.concatOp(self.ser, a, b, axis)
1266
1267 attr = ts.TosaSerializerAttribute()
1268 attr.AxisAttribute(axis)
1269
1270 self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
1271
1272 def build_pad(self, op, a, padding, qinfo):
1273 result_tens = OutputShaper.padOp(self.ser, a, padding)
1274
1275 # Need to turn the padding array into a TOSA tensor here.
1276 # This is one of the few tensor operands that does not get
1277 # randomly generated
Kevin Cheng550ccc52021-03-03 11:21:43 -08001278 padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)
Eric Kunzee5e26762020-10-13 16:11:07 -07001279
Kevin Cheng550ccc52021-03-03 11:21:43 -08001280 self.ser.addOperator(
1281 op, [a.name, padding_tens.name], [result_tens.name], None, qinfo
1282 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001283
1284 def build_reshape(self, op, a, newShape):
1285 result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)
1286
1287 attr = ts.TosaSerializerAttribute()
1288 attr.ReshapeAttribute(newShape)
1289
1290 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1291 return result_tens
1292
1293 def build_reverse(self, op, a, axis):
1294 result_tens = OutputShaper.unaryOp(self.ser, a)
1295
1296 attr = ts.TosaSerializerAttribute()
1297 attr.AxisAttribute(axis)
1298
1299 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1300 return result_tens
1301
1302 def build_transpose(self, op, a, perms):
1303 result_tens = OutputShaper.transposeOp(self.ser, a, perms)
1304
Kevin Cheng550ccc52021-03-03 11:21:43 -08001305 perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))
Eric Kunzee5e26762020-10-13 16:11:07 -07001306
1307 self.ser.addOperator(op, [a.name, perms_tens.name], [result_tens.name])
1308 return result_tens
1309
1310 def build_slice(self, op, a, begin, size):
1311 result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)
1312
1313 attr = ts.TosaSerializerAttribute()
1314 attr.SliceAttribute(begin, size)
1315
1316 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1317 return result_tens
1318
1319 def build_tile(self, op, a, multiples):
1320 result_tens = OutputShaper.tileOp(self.ser, a, multiples)
1321
1322 attr = ts.TosaSerializerAttribute()
1323 attr.TileAttribute(multiples)
1324
1325 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1326 return result_tens
1327
Kevin Cheng77d0f762020-11-24 10:26:32 -08001328 def build_gather(self, op, values):
Eric Kunzee5e26762020-10-13 16:11:07 -07001329
1330 # Create a new indicies tensor
1331 # here with data that doesn't exceed the dimensions of the values tensor
1332
Kevin Cheng550ccc52021-03-03 11:21:43 -08001333 K = values.shape[1] # K
1334 W = self.randInt(
1335 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
1336 ) # W
1337 indicies_arr = np.int32(
1338 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
1339 ) # (N, W)
1340 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001341
Kevin Cheng77d0f762020-11-24 10:26:32 -08001342 result_tens = OutputShaper.gatherOp(self.ser, values, indicies)
Eric Kunzee5e26762020-10-13 16:11:07 -07001343
Kevin Cheng77d0f762020-11-24 10:26:32 -08001344 self.ser.addOperator(op, [values.name, indicies.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001345
1346 return result_tens
1347
Kevin Cheng77d0f762020-11-24 10:26:32 -08001348 def build_scatter(self, op, values_in, input):
1349
1350 # Create a new indicies tensor
1351 # here with data that doesn't exceed the dimensions of the values_in tensor
1352
Kevin Cheng550ccc52021-03-03 11:21:43 -08001353 K = values_in.shape[1] # K
1354 W = input.shape[1] # W
1355 indicies_arr = np.int32(
1356 self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
1357 ) # (N, W)
1358 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001359
1360 result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)
1361
Kevin Cheng550ccc52021-03-03 11:21:43 -08001362 self.ser.addOperator(
1363 op, [values_in.name, indicies.name, input.name], [result_tens.name]
1364 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001365
1366 return result_tens
1367
    def build_resize(
        self,
        op,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
    ):
        """Emit a RESIZE op.

        Integer resizes use the fixed-point stride/offset/shift fields;
        float resizes use stride_fp/offset_fp with shift == 0 (the unused
        set is passed through as zeros by the argument generator).
        """
        result_tens = OutputShaper.resizeOp(
            self.ser,
            input,
            mode,
            stride,
            offset,
            shift,
            stride_fp,
            offset_fp,
            output_dims,
            input_dtype,
            output_dtype,
        )

        attr = ts.TosaSerializerAttribute()

        attr.ResizeAttribute(
            output_dims, stride, offset, shift, stride_fp, offset_fp, mode
        )

        self.ser.addOperator(op, [input.name], [result_tens.name], attr)
        return result_tens
1404
1405 def build_identityn(self, op, val, val2):
1406
Kevin Cheng550ccc52021-03-03 11:21:43 -08001407 result_tens = OutputShaper.unaryOp(self.ser, val)
Eric Kunzee5e26762020-10-13 16:11:07 -07001408 result_tens2 = OutputShaper.unaryOp(self.ser, val2)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001409 self.ser.addOperator(
1410 op, [val.name, val2.name], [result_tens.name, result_tens2.name]
1411 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001412 return result_tens
1413
    def build_placeholder(self, op, val):
        """Wrap a placeholder test in an identity op."""
        # Add an identity op to avoid warning in the reference model
        return self.build_unary(Op.IDENTITY, val)
1417
1418 # Type Conversion
1419 def build_cast(self, op, val, out_dtype):
1420 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
1421 self.ser.addOperator(op, [val.name], [result_tens.name])
1422 return result_tens
1423
1424 def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
1425 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
1426
1427 if per_channel:
1428 nc = val.shape[-1]
1429 else:
1430 nc = 1
1431
1432 in_type_width = self.typeWidth(val.dtype)
1433 out_type_width = self.typeWidth(out_dtype)
1434
Kevin Cheng3a478572021-01-22 17:21:02 -08001435 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001436 input_zp = self.randInt(-128, 128)
1437 in_type_width = in_type_width + 1
1438 elif val.dtype == DType.UINT8:
1439 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07001440 in_type_width = in_type_width + 1
1441 else:
1442 input_zp = 0
1443
Kevin Cheng3a478572021-01-22 17:21:02 -08001444 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001445 output_zp = self.randInt(-128, 128)
1446 out_type_width = out_type_width + 1
1447 elif out_dtype == DType.UINT8:
1448 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07001449 out_type_width = out_type_width + 1
1450 else:
1451 output_zp = 0
1452
1453 # Calculate scale based on:
1454 # scale = a *(2^output_width)/(2^input_width))
1455
1456 a = np.float32(self.rng.random(size=[nc]))
1457 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
1458
1459 if scale32:
1460 pass
1461 # Cap the scaling at 2^15 - 1 for scale16
1462 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
1463 else:
1464 # Cap the scaling at 2^15 - 1 for scale16
1465 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
1466
Kevin Cheng550ccc52021-03-03 11:21:43 -08001467 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001468
1469 multiplier_arr = np.int32(np.zeros(shape=[nc]))
1470 shift_arr = np.int32(np.zeros(shape=[nc]))
1471
1472 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001473 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
1474 scale_arr[i], scale32
1475 )
Kevin Chengaee1fac2020-11-11 13:54:06 -08001476 if shift_arr[i] < 2 or shift_arr[i] > 62:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001477 self.ser.setExpectedFailure(True, "OpRescale: invalid shift value")
Eric Kunzee5e26762020-10-13 16:11:07 -07001478
Kevin Cheng550ccc52021-03-03 11:21:43 -08001479 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07001480
1481 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08001482 attr.RescaleAttribute(
1483 input_zp,
1484 output_zp,
1485 multiplier_arr,
1486 shift_arr,
1487 scale32,
1488 double_round,
1489 per_channel,
1490 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001491
1492 self.ser.addOperator(op, [val.name], [result_tens.name], attr)
1493 return result_tens
1494
    def build_cond_if_const(self, op, then_tens, else_tens, cond):
        """Build a COND_IF whose THEN/ELSE blocks are constant tensors.

        The supplied then/else tensors are used only for their shape; each
        block body is a fresh random INT32 const of that shape.
        """
        # For cond_if with constants, we're supplied with then/else tensors that we ignore
        # (except for the generated shape) and the condition. Build Then/Else blocks
        # and fill them with const nodes for the body.

        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        # Make then/else tensors
        out_shape = then_tens.shape
        then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
        else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

        # And the result tensor based on any of the outputs
        result_tens = self.ser.addOutput(out_shape, DType.INT32)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(op, [cond_tens.name], [result_tens.name], attr)

        self.ser.startBasicBlock(then_block)
        # Build the actual then/else tensors inside their blocks
        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
        self.ser.addOutputTensor(then_tens)

        self.ser.startBasicBlock(else_block)
        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
        self.ser.addOutputTensor(else_tens)

        return result_tens
1530
    def build_cond_if_binary(self, op, a, b, cond):
        """Build a COND_IF whose THEN block adds a+b and whose ELSE block
        subtracts a-b, selected by the constant condition."""
        # For cond_if with a binary op in the then/else blocks, take a and b and
        # alternately add or subtract them based on the condition

        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        result_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.currBasicBlock.addOutput(result_tens.name)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(
            op, [cond_tens.name, a.name, b.name], [result_tens.name], attr
        )

        # THEN block computes a + b
        self.ser.startBasicBlock(then_block)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        then_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])

        # ELSE block computes a - b
        self.ser.startBasicBlock(else_block)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        else_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])

        return result_tens
1565
    def build_while_loop(self, op, a, iter_val):
        """Build a WHILE_LOOP that runs iter_val iterations, accumulating
        acc += a each pass and decrementing the iteration counter.

        Returns the accumulator output tensor.
        NOTE(review): the local name `iter` shadows the builtin; left
        unchanged here to keep the code byte-identical.
        """
        iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

        cond_block = "COND_BLOCK"
        body_block = "BODY_BLOCK"

        attr = ts.TosaSerializerAttribute()
        attr.WhileLoopAttribute(cond_block, body_block)

        # Accumulator tensor
        # acc = self.ser.addOutput(a.shape, a.dtype)
        acc_init_val = np.int32(np.zeros(a.shape))
        acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

        # Intermediate/output tensors for everything going through the loop
        iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        a_out = self.ser.addIntermediate(a.shape, a.dtype)
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

        # While_loop operator
        self.ser.addOperator(
            op,
            [iter.name, a.name, acc.name],
            [iter_out.name, a_out.name, acc_out.name],
            attr,
        )

        # COND block (input: iter, output: cond_tens ) — loop while iter > 0
        self.ser.startBasicBlock(cond_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
        cond_tens = self.ser.addOutput([], DType.BOOL)
        self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])

        # BODY block (input: a, acc, iter, output: a, acc, iter)
        # Note that local intermediate tensors need to be declared here for the outputs
        self.ser.startBasicBlock(body_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
        iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
        self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
        self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
        self.ser.addOutputTensor(iter_body_out)
        self.ser.addOutputTensor(a)
        self.ser.addOutputTensor(acc_body_out)

        return acc_out
1618
Kevin Cheng550ccc52021-03-03 11:21:43 -08001619 def genOpTestList(
1620 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None
1621 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07001622
1623 try:
1624 op = self.TOSA_OP_LIST[opName]
1625 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001626 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07001627
1628 # Initialize a new random number generator
1629 self.rng = np.random.default_rng(self.random_seed)
1630
Kevin Cheng550ccc52021-03-03 11:21:43 -08001631 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001632
1633 # Generate the lists of arguments
Kevin Cheng550ccc52021-03-03 11:21:43 -08001634 rmin, rmax = op["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001635
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01001636 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
1637 default_test_rank_range = range(1, 5)
1638
Eric Kunzee5e26762020-10-13 16:11:07 -07001639 # Test list consists of a tuple of:
1640 # (opName, testNameStr, dtype, shapeList, argumentsList)
1641 testList = []
1642
1643 if not shapeFilter:
1644 shapeFilter = [None]
1645
1646 for r in range(rmin, rmax + 1):
1647
1648 # Filter out the rank?
1649 if rankFilter is not None and r not in rankFilter:
1650 continue
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01001651 if rankFilter is None and shapeFilter[0] is None and r not in default_test_rank_range:
1652 continue
Eric Kunzee5e26762020-10-13 16:11:07 -07001653
Kevin Cheng550ccc52021-03-03 11:21:43 -08001654 for t in op["types"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07001655
1656 # Filter tests based on dtype?
1657 if dtypeFilter is not None:
Les Bell30e46802021-07-23 09:43:31 +01001658 if not (t in dtypeFilter or (isinstance(t, list) and t[0] in dtypeFilter)):
Eric Kunzee5e26762020-10-13 16:11:07 -07001659 continue
1660
1661 # Create the placeholder and const tensors
1662 for shape in shapeFilter:
1663 # A None shape chooses a random shape of a given rank
1664
1665 # Filter out by rank
1666 if shape is not None and len(shape) != r:
1667 continue
1668
1669 self.setTargetShape(shape)
1670 shapeList = tgen_fcn(self, op, r)
1671
1672 shapeStr = self.shapeStr(shapeList[0])
1673 typeStr = self.typeStr(t)
1674
1675 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
1676 argList = []
1677 if agen_fcn:
1678 argList = agen_fcn(self, opName, shapeList, t)
1679 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001680 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07001681
1682 for argStr, args in argList:
1683 if argStr:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001684 testStr = "{}_{}_{}_{}".format(
1685 opName, shapeStr, typeStr, argStr
1686 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001687 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001688 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001689
1690 testList.append((opName, testStr, t, shapeList, args))
1691
1692 return testList
1693
    def serializeTest(self, opName, testStr, dtype_or_dtypeList, shapeList, testArgs):
        """Build and serialize one test case for operator ``opName``.

        Looks up the op table entry, creates the operand tensors (with
        op-specific value constraints where the op would otherwise hit
        undefined behaviour), calls the op's build function, and writes
        the serialized test out under the name ``testStr``.
        """
        try:
            op = self.TOSA_OP_LIST[opName]
        except KeyError as e:
            raise Exception("Cannot find op with name {}".format(opName))

        # Create a serializer
        self.createSerializer(opName, testStr)

        # (build function, tensor generator, argument generator) for this op.
        build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
        # Number of placeholder (runtime) and constant operands.
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount

        # A single dtype is broadcast to all operands; a list gives
        # per-operand dtypes (e.g. the conv (input, weight, bias) triples).
        if isinstance(dtype_or_dtypeList, list):
            dtypeList = dtype_or_dtypeList
        else:
            dtypeList = [dtype_or_dtypeList] * (num_operands)

        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

        # Optional quantization-info generator for quantized ops.
        try:
            qgen = op["qgen"]
        except KeyError:
            qgen = None

        # Build the random tensor operands and the test
        tens = []

        # If test is ArithmeticRightShift, force value of operand[1] to be within [0, num_bits]
        if op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if idx == 1:
                    # Shift amount: clamp to the bit width of the operand dtype.
                    if dtypeList[idx] == DType.INT8:
                        arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
                    elif dtypeList[idx] == DType.INT16:
                        arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
                    elif dtypeList[idx] == DType.INT32:
                        arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
                    else:
                        raise Exception("OpArithmeticRightShift: invalid input dtype")
                else:
                    arr = self.getRandTensor(shape, dtypeList[idx])
                placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))

            tens.extend(placeholders)
        elif op["op"] == Op.SELECT:
            # Set datatype of condition tensor to boolean
            dtypeList[0] = DType.BOOL
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
            )
            tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
        elif op["op"] == Op.DIV:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.Div must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.DIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
            # Keep redrawing both tensors until neither case is present.
            while True:
                dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
                    continue

                break

            placeholders.append(
                self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            placeholders.append(
                self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            tens.extend(placeholders)
        elif op["op"] == Op.MUL:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            if dtypeList[0] == DType.FLOAT:
                tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure multiply result in int32 range
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                else:
                    raise Exception("OpMul: invalid input dtype")

                # NOTE(review): this loop ignores idx/shape and always draws
                # from shapeList[0]/shapeList[1], so the arrays are redrawn
                # once per operand and only the final draw is kept — the loop
                # looks vestigial. Confirm intent before cleaning up (it does
                # affect RNG draw order, hence test reproducibility).
                for idx, shape in enumerate(shapeList[:]):
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                    a_arr = np.int32(
                        self.rng.integers(low=low, high=high, size=shapeList[0])
                    )
                    b_arr = np.int32(
                        self.rng.integers(low=low, high=high, size=shapeList[1])
                    )

                # Halve both operands until the (rounded, shifted) product of
                # every element pair fits in int32.
                i = 0
                while True:

                    a_arr_64 = a_arr.astype(np.int64)
                    b_arr_64 = b_arr.astype(np.int64)

                    if shift > 0:
                        rounding = 1 << (shift - 1)
                        result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                    else:
                        result_arr = a_arr_64 * b_arr_64

                    if (result_arr > -(2 ** 31)).all() and (
                        result_arr <= ((2 ** 31) - 1)
                    ).all():
                        break

                    # i counts halving passes; currently unused beyond this.
                    i = i + 1
                    a_arr = a_arr // 2
                    b_arr = b_arr // 2

                placeholders.append(
                    self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)
        else:
            # Default case: first pCount operands are placeholders, the rest consts.
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
            )
            tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        if qgen is not None:
            qinfo = qgen(self, op, dtype_or_dtypeList)
        else:
            qinfo = None

        try:
            if qinfo is not None:
                resultName = build_fcn(self, op["op"], *tens, *testArgs, qinfo)
            else:
                resultName = build_fcn(self, op["op"], *tens, *testArgs)
        except TypeError as e:
            # Dump the build call's inputs before re-raising to aid debugging
            # of op-table/build-function signature mismatches.
            print(
                "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
                    build_fcn, tens, testArgs
                )
            )
            raise e

        # Save the serialized test
        self.serialize("test")
Eric Kunzee5e26762020-10-13 16:11:07 -07001877
1878 def createDynamicOpLists(self):
1879
1880 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng550ccc52021-03-03 11:21:43 -08001881 KERNELS = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07001882
1883 for k in KERNELS:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001884 testName = "conv2d_{}x{}".format(k[0], k[1])
1885 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
1886 self.TOSA_OP_LIST[testName]["filter"] = k
1887 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001888
Kevin Cheng550ccc52021-03-03 11:21:43 -08001889 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
1890 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
1891 "depthwise_conv2d_TEMPLATE"
1892 ].copy()
1893 self.TOSA_OP_LIST[testName]["filter"] = k
1894 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001895
Kevin Cheng550ccc52021-03-03 11:21:43 -08001896 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
1897 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
1898 "transpose_conv2d_TEMPLATE"
1899 ].copy()
1900 self.TOSA_OP_LIST[testName]["filter"] = k
1901 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001902
1903 # Delete any templates after having created any dynamic ops
1904 # This is a two-pass operation because it's bad practice to delete
1905 # keys from dictionaries while iterating
1906 keyList = []
1907 for k in self.TOSA_OP_LIST:
1908 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001909 if self.TOSA_OP_LIST[k]["template"] == True:
Eric Kunzee5e26762020-10-13 16:11:07 -07001910 keyList.append(k)
1911 continue
1912 except KeyError:
1913 pass
1914
1915 for k in keyList:
1916 del self.TOSA_OP_LIST[k]
1917
1918 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001919 """Fill in default fields for ops if they aren't already specified.
1920 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07001921 for op in self.TOSA_OP_LIST:
1922
1923 # Required fields
1924 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001925 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001926 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001927 raise Exception(
1928 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
1929 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001930
1931 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001932 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001933 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001934 raise Exception(
1935 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
1936 op
1937 )
1938 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001939
1940 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001941 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001942 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001943 raise Exception(
1944 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
1945 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001946
1947 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001948 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001949 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001950 raise Exception(
1951 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
1952 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001953
1954 # Put in default rank range, if missing
1955 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001956 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001957 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001958 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07001959
    # Tensor operator list
    # 'op': op name
    # 'operands': tuple of (placeholder, const) operands
    # 'rank': optional, restricts rank to tuple inclusive of (min, max),
    #   if not specified, defaults to DEFAULT_RANK_RANGE
    #   (1..TOSA_TENSOR_MAX_RANK) — see initOpListDefaults()
    # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
    # 'types': array of datatypes to be tested
    TYPE_FP = [DType.FLOAT]

    TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
    TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

    TYPE_BOOL = [DType.BOOL]
    TYPE_FI32 = [DType.FLOAT, DType.INT32]
    TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
    TYPE_FI16 = [DType.FLOAT, DType.INT16]

    # "Narrow" integer widths (no INT32) plus float.
    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

    # Per-operand (input, weight, bias) dtype triples for convolution-style
    # ops; a bare DType.FLOAT entry means all three operands are float.
    TYPE_CONV2D = [
        [DType.INT8, DType.INT8, DType.INT32],
        [DType.INT16, DType.INT8, DType.INT48],
        DType.FLOAT,
    ]

    DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07001986
1987 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08001988 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08001989 "argmax": {
1990 "op": Op.ARGMAX,
1991 "operands": (1, 0),
1992 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
1993 "types": TYPE_NARROW_INT_FP,
1994 },
Jared Smolens573ecd42021-03-04 15:24:10 -08001995 "avg_pool2d": {
1996 "op": Op.AVG_POOL2D,
1997 "operands": (1, 0),
1998 "rank": (4, 4),
1999 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
2000 "qgen": TosaQuantGen.qgUnary,
2001 "types": TYPE_NARROW_INT_FP,
2002 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002003 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002004 "conv2d_TEMPLATE": {
2005 "op": Op.CONV2D,
2006 "operands": (1, 2),
2007 "rank": (4, 4),
2008 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D),
2009 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07002010 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002011 "template": True,
2012 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002013 # Conv3d TBD
Eric Kunzee5e26762020-10-13 16:11:07 -07002014 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002015 "depthwise_conv2d_TEMPLATE": {
2016 "op": Op.DEPTHWISE_CONV2D,
2017 "operands": (1, 2),
2018 "filter": [1, 1],
2019 "rank": (4, 4),
2020 "build_fcn": (
2021 build_depthwise_conv2d,
2022 TosaTensorGen.tgDepthwiseConv2D,
2023 TosaArgGen.agConv2D,
2024 ),
2025 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07002026 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002027 "template": True,
2028 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002029 "fully_connected": {
2030 "op": Op.FULLY_CONNECTED,
2031 "operands": (1, 2),
2032 "rank": (2, 2),
2033 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
2034 "qgen": TosaQuantGen.qgConv,
2035 "types": TYPE_CONV2D,
2036 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002037 "matmul": {
2038 "op": Op.MATMUL,
2039 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002040 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08002041 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
2042 "qgen": TosaQuantGen.qgMatmul,
2043 "types": TYPE_NARROW_INT_FP,
2044 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002045 "max_pool2d": {
2046 "op": Op.MAX_POOL2D,
2047 "operands": (1, 0),
2048 "rank": (4, 4),
2049 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
2050 "types": TYPE_NARROW_INT_FP,
2051 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002052 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002053 "transpose_conv2d_TEMPLATE": {
2054 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002055 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002056 "rank": (4, 4),
2057 "build_fcn": (
2058 build_transpose_conv2d,
2059 TosaTensorGen.tgTransposeConv2D,
2060 TosaArgGen.agTransposeConv2D,
2061 ),
2062 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07002063 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002064 "template": True,
2065 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002066 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002067 "clamp": {
2068 "op": Op.CLAMP,
2069 "operands": (1, 0),
2070 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
2071 "types": TYPE_NARROW_INT_FP,
2072 },
2073 "relun": {
2074 "op": Op.RELUN,
2075 "operands": (1, 0),
2076 "build_fcn": (build_relun, TosaTensorGen.tgBasic, None),
2077 "types": TYPE_FI32,
2078 },
2079 "sigmoid": {
2080 "op": Op.SIGMOID,
2081 "operands": (1, 0),
2082 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
2083 "types": TYPE_FP,
2084 },
2085 "tanh": {
2086 "op": Op.TANH,
2087 "operands": (1, 0),
2088 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
2089 "types": TYPE_FP,
2090 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002091 # Elementwise Binary Operators
2092 "add": {
2093 "op": Op.ADD,
2094 "operands": (2, 0),
2095 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2096 "types": TYPE_FI32,
2097 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002098 "arithmetic_right_shift": {
2099 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2100 "operands": (2, 0),
2101 "build_fcn": (
2102 build_arithmetic_right_shift,
2103 TosaTensorGen.tgBroadcastFuzz,
2104 TosaArgGen.agArithmeticRightShift,
2105 ),
2106 "types": TYPE_INT,
2107 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002108 "bitwise_and": {
2109 "op": Op.BITWISE_AND,
2110 "operands": (2, 0),
2111 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2112 "types": TYPE_INT,
2113 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002114 "bitwise_or": {
2115 "op": Op.BITWISE_OR,
2116 "operands": (2, 0),
2117 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2118 "types": TYPE_INT,
2119 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002120 "bitwise_xor": {
2121 "op": Op.BITWISE_XOR,
2122 "operands": (2, 0),
2123 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2124 "types": TYPE_INT,
2125 },
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002126 "div": {
2127 "op": Op.DIV,
2128 "operands": (2, 0),
2129 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2130 "types": [DType.INT32],
2131 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002132 "logical_and": {
2133 "op": Op.LOGICAL_AND,
2134 "operands": (2, 0),
2135 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2136 "types": TYPE_BOOL,
2137 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002138 "logical_left_shift": {
2139 "op": Op.LOGICAL_LEFT_SHIFT,
2140 "operands": (2, 0),
2141 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2142 "types": TYPE_INT,
2143 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002144 "logical_right_shift": {
2145 "op": Op.LOGICAL_RIGHT_SHIFT,
2146 "operands": (2, 0),
2147 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2148 "types": TYPE_INT,
2149 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002150 "logical_or": {
2151 "op": Op.LOGICAL_OR,
2152 "operands": (2, 0),
2153 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2154 "types": TYPE_BOOL,
2155 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002156 "logical_xor": {
2157 "op": Op.LOGICAL_XOR,
2158 "operands": (2, 0),
2159 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2160 "types": TYPE_BOOL,
2161 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002162 "maximum": {
2163 "op": Op.MAXIMUM,
2164 "operands": (2, 0),
2165 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2166 "types": TYPE_FI32,
2167 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002168 "minimum": {
2169 "op": Op.MINIMUM,
2170 "operands": (2, 0),
2171 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2172 "types": TYPE_FI32,
2173 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002174 "mul": {
2175 "op": Op.MUL,
2176 "operands": (2, 0),
2177 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
2178 "types": TYPE_INT_FP,
2179 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002180 "pow": {
2181 "op": Op.POW,
2182 "operands": (2, 0),
2183 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
2184 "types": TYPE_FP,
2185 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002186 "sub": {
2187 "op": Op.SUB,
2188 "operands": (2, 0),
2189 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2190 "types": TYPE_FI32,
2191 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002192 "table": {
2193 "op": Op.TABLE,
2194 # Use the automatic generation functions to create the input array
2195 # but create the table tensor in the build function, as it may be
2196 # a different type from the input
2197 "operands": (1, 0),
2198 "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002199 "types": [DType.INT8, DType.INT16],
Jared Smolens573ecd42021-03-04 15:24:10 -08002200 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002201 # Elementwise Unary operators
2202 "abs": {
2203 "op": Op.ABS,
2204 "operands": (1, 0),
2205 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2206 "types": TYPE_FI32,
2207 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002208 "bitwise_not": {
2209 "op": Op.BITWISE_NOT,
2210 "operands": (1, 0),
2211 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2212 "types": TYPE_INT,
2213 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002214 "ceil": {
2215 "op": Op.CEIL,
2216 "operands": (1, 0),
2217 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2218 "types": TYPE_FP,
2219 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002220 "clz": {
2221 "op": Op.CLZ,
2222 "operands": (1, 0),
2223 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2224 "types": [DType.INT32],
2225 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002226 "exp": {
2227 "op": Op.EXP,
2228 "operands": (1, 0),
2229 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2230 "types": TYPE_FP,
2231 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002232 "floor": {
2233 "op": Op.FLOOR,
2234 "operands": (1, 0),
2235 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2236 "types": TYPE_FP,
2237 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002238 "log": {
2239 "op": Op.LOG,
2240 "operands": (1, 0),
2241 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2242 "types": TYPE_FP,
2243 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002244 "logical_not": {
2245 "op": Op.LOGICAL_NOT,
2246 "operands": (1, 0),
2247 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2248 "types": TYPE_BOOL,
2249 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002250 "negate": {
2251 "op": Op.NEGATE,
2252 "operands": (1, 0),
2253 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2254 "qgen": TosaQuantGen.qgUnary,
2255 "types": TYPE_INT_FP,
2256 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002257 "reciprocal": {
2258 "op": Op.RECIPROCAL,
2259 "operands": (1, 0),
2260 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2261 "types": TYPE_FP,
2262 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002263 "rsqrt": {
2264 "op": Op.RSQRT,
2265 "operands": (1, 0),
2266 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2267 "types": TYPE_FP,
2268 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002269 # Elementwise Ternary operators
2270 "select": {
2271 "op": Op.SELECT,
2272 "operands": (3, 0),
2273 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
2274 "types": TYPE_FIB,
2275 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002276 # Comparison operators
2277 "equal": {
2278 "op": Op.EQUAL,
2279 "operands": (2, 0),
2280 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2281 "types": TYPE_FI32,
2282 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002283 "greater_equal": {
2284 "op": Op.GREATER_EQUAL,
2285 "operands": (2, 0),
2286 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2287 "types": TYPE_FI32,
2288 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002289 "greater": {
2290 "op": Op.GREATER,
2291 "operands": (2, 0),
2292 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2293 "types": TYPE_FI32,
2294 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002295 # Reduction operators
2296 "reduce_all": {
2297 "op": Op.REDUCE_ALL,
2298 "operands": (1, 0),
2299 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2300 "types": TYPE_BOOL,
2301 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002302 "reduce_any": {
2303 "op": Op.REDUCE_ANY,
2304 "operands": (1, 0),
2305 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2306 "types": TYPE_BOOL,
2307 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002308 "reduce_max": {
2309 "op": Op.REDUCE_MAX,
2310 "operands": (1, 0),
2311 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2312 "types": TYPE_INT_FP,
2313 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002314 "reduce_min": {
2315 "op": Op.REDUCE_MAX,
2316 "operands": (1, 0),
2317 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2318 "types": TYPE_INT_FP,
2319 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002320 "reduce_product": {
2321 "op": Op.REDUCE_PRODUCT,
2322 "operands": (1, 0),
2323 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2324 "types": TYPE_FP,
2325 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002326 "reduce_sum": {
2327 "op": Op.REDUCE_SUM,
2328 "operands": (1, 0),
2329 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2330 "types": TYPE_FI32,
2331 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002332 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002333 "concat": {
2334 "op": Op.CONCAT,
2335 "operands": (2, 0),
2336 "build_fcn": (build_concat, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2337 "types": TYPE_FIB,
2338 },
2339 "pad": {
2340 "op": Op.PAD,
2341 "operands": (1, 0),
2342 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
2343 "qgen": TosaQuantGen.qgPad,
2344 "types": TYPE_FIB,
2345 },
2346 "reshape": {
2347 "op": Op.RESHAPE,
2348 "operands": (1, 0),
2349 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
2350 "types": TYPE_FIB,
2351 },
2352 "reverse": {
2353 "op": Op.REVERSE,
2354 "operands": (1, 0),
2355 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2356 "types": TYPE_FIB,
2357 },
2358 "slice": {
2359 "op": Op.SLICE,
2360 "operands": (1, 0),
2361 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
2362 "types": TYPE_FIB,
2363 },
2364 "tile": {
2365 "op": Op.TILE,
2366 "operands": (1, 0),
2367 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
2368 "types": TYPE_FIB,
2369 },
2370 "transpose": {
2371 "op": Op.TRANSPOSE,
2372 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01002373 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002374 "build_fcn": (
2375 build_transpose,
2376 TosaTensorGen.tgBasic,
2377 TosaArgGen.agTranspose,
2378 ),
2379 "types": TYPE_FIB,
2380 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002381 # Data nodes
2382 "const": {
2383 "op": Op.CONST,
2384 "operands": (1, 0),
2385 "build_fcn": (build_placeholder, TosaTensorGen.tgBasic, None),
2386 "types": TYPE_FIB,
2387 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002388 "identity": {
2389 "op": Op.IDENTITY,
2390 "operands": (1, 0),
2391 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2392 "types": TYPE_FIB,
2393 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002394 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08002395 "gather": {
2396 "op": Op.GATHER,
2397 # Only specify 'values' tensor here. 'indices' is generated in op building stage
2398 "operands": (1, 0),
2399 "rank": (3, 3),
2400 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
2401 "types": TYPE_INT_FP,
2402 },
2403 "scatter": {
2404 "op": Op.SCATTER,
2405 # Only specify 'values_in' tensor here.
2406 #'indices' and 'input' are generated in op building stage
2407 "operands": (2, 0),
2408 "rank": (3, 3),
2409 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
2410 "types": TYPE_INT_FP,
2411 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002412 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08002413 "resize": {
2414 "op": Op.RESIZE,
2415 "operands": (1, 0),
2416 "rank": (4, 4),
2417 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
2418 "types": [DType.INT8, DType.INT16, DType.FLOAT],
2419 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002420 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08002421 "cast": {
2422 "op": Op.CAST,
2423 "operands": (1, 0),
2424 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
2425 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
2426 },
2427 "rescale": {
2428 "op": Op.RESCALE,
2429 "operands": (1, 0),
2430 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002431 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002432 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002433 # Custom
2434 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08002435 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07002436 # Two varients of cond_if, one that generates one of two constant tensors (no
2437 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
2438 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002439 "cond_if_const": {
2440 "op": Op.COND_IF,
2441 "operands": (0, 2),
2442 "build_fcn": (
2443 build_cond_if_const,
2444 TosaTensorGen.tgBasic,
2445 TosaArgGen.agCondIf,
2446 ),
2447 "types": [DType.BOOL],
2448 },
2449 "cond_if_binary": {
2450 "op": Op.COND_IF,
2451 "operands": (2, 0),
2452 "build_fcn": (
2453 build_cond_if_binary,
2454 TosaTensorGen.tgBasic,
2455 TosaArgGen.agCondIf,
2456 ),
2457 "types": TYPE_FI32,
2458 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002459 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002460 "while_loop": {
2461 "op": Op.WHILE_LOOP,
2462 "operands": (0, 1),
2463 "build_fcn": (
2464 build_while_loop,
2465 TosaTensorGen.tgBasic,
2466 TosaArgGen.agWhileLoop,
2467 ),
2468 "types": [DType.INT32],
2469 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002470 }
2471
Kevin Cheng550ccc52021-03-03 11:21:43 -08002472
class OutputShaper:
    """Compute the expected output shape and dtype for common operator classes.

    Every public method is a @staticmethod that takes the serializer plus the
    operator's input tensors/attributes and returns the result of
    ser.addOutput(shape, dtype) for the operator's output tensor.
    """

    # Methods in this class compute the expected output shape and datatype
    # for common classes of operations
    def __init__(self):
        # Stateless: all functionality is provided via static methods.
        pass

    # These methods return arguments that can be used for
    # creating a new output tensor
2481 @staticmethod
2482 def binaryBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002483 assert len(a.shape) == len(b.shape)
2484 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002485
2486 shape = []
2487 for i in range(len(a.shape)):
2488 if a.shape[i] == 1:
2489 shape.append(b.shape[i])
2490 else:
2491 shape.append(a.shape[i])
2492
Kevin Cheng550ccc52021-03-03 11:21:43 -08002493 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002494
2495 @staticmethod
2496 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002497 assert len(a.shape) == len(b.shape)
2498 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002499
2500 shape = []
2501 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002502 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07002503 shape.append(a.shape[i])
2504
Kevin Cheng550ccc52021-03-03 11:21:43 -08002505 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002506
2507 @staticmethod
2508 def unaryOp(ser, a):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002509 return ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002510
2511 @staticmethod
2512 def selectOp(ser, cond, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002513 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
2514 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002515
2516 shape = []
2517 for i in range(len(a.shape)):
2518 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
2519
Kevin Cheng550ccc52021-03-03 11:21:43 -08002520 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002521
2522 @staticmethod
2523 def binaryComparisonOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002524 assert len(a.shape) == len(b.shape)
2525 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002526
2527 # Do broadcast
2528 shape = []
2529 for i in range(len(a.shape)):
2530 if a.shape[i] == 1:
2531 shape.append(b.shape[i])
2532 else:
2533 shape.append(a.shape[i])
2534
2535 # Force the output type to bool
Kevin Cheng550ccc52021-03-03 11:21:43 -08002536 return ser.addOutput(shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07002537
2538 @staticmethod
2539 def reduceOp(ser, a, axis):
2540
2541 shape = a.shape.copy()
2542
2543 shape[axis] = 1
2544
Kevin Cheng550ccc52021-03-03 11:21:43 -08002545 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002546
2547 @staticmethod
2548 def argmaxOp(ser, a, axis):
2549 shape = a.shape.copy()
2550 del shape[axis]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002551 return ser.addOutput(shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07002552
2553 @staticmethod
2554 def conv2dOp(ser, ifm, filter, strides, padding, dilations):
2555
2556 # IFM: NHWC
2557 # Filter: OHWI
2558 # OFM: NHWC
2559
2560 if len(padding) == 2:
2561 # Expand padding to 4 parameters in the case of transpose_conv2d
2562 # From H,W to T,B,L,R
2563 padding = [padding[0], padding[0], padding[1], padding[1]]
2564
Kevin Cheng550ccc52021-03-03 11:21:43 -08002565 h = (
2566 ifm.shape[1]
2567 - filter.shape[1]
2568 - (filter.shape[1] - 1) * (dilations[0] - 1)
2569 + padding[0]
2570 + padding[1]
2571 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002572
Kevin Cheng550ccc52021-03-03 11:21:43 -08002573 w = (
2574 ifm.shape[2]
2575 - filter.shape[2]
2576 - (filter.shape[2] - 1) * (dilations[1] - 1)
2577 + padding[2]
2578 + padding[3]
2579 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002580
2581 if h <= 0 or w <= 0:
2582 # Invalid test parameters?
2583 h = 0
2584 w = 0
Kevin Cheng550ccc52021-03-03 11:21:43 -08002585 ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002586
2587 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
2588
Kevin Cheng3a478572021-01-22 17:21:02 -08002589 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002590 out_dtype = DType.INT32
2591 elif ifm.dtype == DType.INT16:
2592 out_dtype = DType.INT48
2593 elif ifm.dtype == DType.FLOAT:
2594 out_dtype = DType.FLOAT
2595 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002596 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002597
Kevin Cheng550ccc52021-03-03 11:21:43 -08002598 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002599
2600 @staticmethod
2601 def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
2602 # IFM: NHWC
2603 # Filter: HWCM
2604 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08002605 h = (
2606 ifm.shape[1]
2607 - filter.shape[0]
2608 - (filter.shape[0] - 1) * (dilations[0] - 1)
2609 + padding[0]
2610 + padding[1]
2611 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002612
Kevin Cheng550ccc52021-03-03 11:21:43 -08002613 w = (
2614 ifm.shape[2]
2615 - filter.shape[1]
2616 - (filter.shape[1] - 1) * (dilations[1] - 1)
2617 + padding[2]
2618 + padding[3]
2619 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002620
2621 if h <= 0 or w <= 0:
2622 # Invalid test parameters?
2623 h = 0
2624 w = 0
Kevin Cheng550ccc52021-03-03 11:21:43 -08002625 ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002626
2627 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
2628
Kevin Cheng3a478572021-01-22 17:21:02 -08002629 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002630 out_dtype = DType.INT32
2631 elif ifm.dtype == DType.INT16:
2632 out_dtype = DType.INT48
2633 elif ifm.dtype == DType.FLOAT:
2634 out_dtype = DType.FLOAT
2635 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002636 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002637
Kevin Cheng550ccc52021-03-03 11:21:43 -08002638 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002639
2640 @staticmethod
2641 def pool2dOp(ser, ifm, kernel, stride, pad):
2642 # input: NHWC
2643 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
2644 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
2645
2646 if h <= 0 or w <= 0:
2647 # Invalid test parameters?
2648 h = 0
2649 w = 0
Kevin Cheng550ccc52021-03-03 11:21:43 -08002650 ser.setExpectedFailure(True, "Invalid combination of pooling parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002651
2652 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002653 return ser.addOutput(ofm_shape, ifm.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002654
2655 @staticmethod
2656 def fullyConnectedOp(ser, input, filter):
2657 # input: N, IC
2658 # filter: OC, IC
2659 # output: N, OC
2660
2661 output_shape = [input.shape[0], filter.shape[0]]
2662
Kevin Cheng3a478572021-01-22 17:21:02 -08002663 if input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002664 out_dtype = DType.INT32
2665 elif input.dtype == DType.INT16:
2666 out_dtype = DType.INT48
2667 elif input.dtype == DType.FLOAT:
2668 out_dtype = DType.FLOAT
2669 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002670 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002671
Kevin Cheng550ccc52021-03-03 11:21:43 -08002672 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002673
2674 @staticmethod
2675 def matmulOp(ser, a, b):
Kevin Cheng2d60f002021-06-09 14:18:32 -07002676 # a: N, H, C
2677 # b: N, C, W
2678 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07002679
Kevin Cheng2d60f002021-06-09 14:18:32 -07002680 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002681
Kevin Cheng3a478572021-01-22 17:21:02 -08002682 if a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002683 out_dtype = DType.INT32
2684 elif a.dtype == DType.INT16:
2685 out_dtype = DType.INT48
2686 elif a.dtype == DType.FLOAT:
2687 out_dtype = DType.FLOAT
2688 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002689 raise Exception("UNsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002690
Kevin Cheng550ccc52021-03-03 11:21:43 -08002691 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002692
2693 @staticmethod
2694 def concatOp(ser, a, b, axis):
2695
2696 output_shape = a.shape.copy()
2697 output_shape[axis] = a.shape[axis] + b.shape[axis]
2698
Kevin Cheng550ccc52021-03-03 11:21:43 -08002699 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002700
2701 @staticmethod
2702 def padOp(ser, a, padding):
2703
2704 output_shape = a.shape.copy()
2705
2706 for i in range(len(output_shape)):
2707 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
2708
Kevin Cheng550ccc52021-03-03 11:21:43 -08002709 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002710
2711 @staticmethod
2712 def reshapeOp(ser, a, shape):
2713 output_shape = shape.copy()
2714
2715 totalElements = 1
2716 for i in a.shape:
2717 totalElements *= i
2718
2719 # If there are any -1 elements, figure out what that dimension must be
2720 totalOutputElements = 1
2721 for i in output_shape:
2722 if i != -1:
2723 totalOutputElements *= i
2724
2725 # And fill it in
2726 for i in range(len(output_shape)):
2727 if output_shape[i] == -1:
2728 output_shape[i] = totalElements // totalOutputElements
2729
Kevin Cheng550ccc52021-03-03 11:21:43 -08002730 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002731
2732 @staticmethod
2733 def sliceOp(ser, a, begin, size):
2734
2735 output_shape = size.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002736 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002737
2738 @staticmethod
2739 def tileOp(ser, a, multiples):
2740
2741 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002742 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07002743
2744 for i in range(len(output_shape)):
2745 output_shape[i] = a.shape[i] * multiples[i]
2746
Kevin Cheng550ccc52021-03-03 11:21:43 -08002747 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002748
2749 @staticmethod
2750 def transposeOp(ser, a, perms):
2751 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002752 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07002753
2754 for i in range(len(output_shape)):
2755 output_shape[i] = a.shape[perms[i]]
2756
Kevin Cheng550ccc52021-03-03 11:21:43 -08002757 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002758
2759 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -08002760 def gatherOp(ser, values, indices):
2761 assert len(values.shape) == 3
2762 assert len(indices.shape) == 2
2763 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07002764
Kevin Cheng77d0f762020-11-24 10:26:32 -08002765 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
2766
Kevin Cheng550ccc52021-03-03 11:21:43 -08002767 return ser.addOutput(output_shape, values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08002768
2769 @staticmethod
2770 def scatterOp(ser, values_in, indices, input):
2771 assert len(values_in.shape) == 3
2772 assert len(indices.shape) == 2
2773 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08002774 assert values_in.shape[0] == indices.shape[0] # N
2775 assert input.shape[1] == indices.shape[1] # W
2776 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08002777
2778 output_shape = values_in.shape
2779
Kevin Cheng550ccc52021-03-03 11:21:43 -08002780 return ser.addOutput(output_shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002781
2782 @staticmethod
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002783 def tableOp(ser, input, table_dtype):
2784 # Same shape as the input, but dtype dependent on table dtype
2785 assert table_dtype == DType.INT16 or table_dtype == DType.INT8
2786 output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
2787 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002788
2789 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08002790 def resizeOp(
2791 ser,
2792 input,
2793 mode,
2794 stride,
2795 offset,
2796 shift,
2797 stride_fp,
2798 offset_fp,
2799 output_dims,
2800 input_dtype,
2801 output_dtype,
2802 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002803
2804 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
2805
Kevin Cheng77d0f762020-11-24 10:26:32 -08002806 if input_dtype == DType.FLOAT:
2807 if stride_fp[0] <= 0 or stride_fp[1] <= 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002808 ser.setExpectedFailure(True, "Negative or zero stride")
Kevin Cheng77d0f762020-11-24 10:26:32 -08002809 else:
2810 if stride[0] <= 0 or stride[1] <= 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002811 ser.setExpectedFailure(True, "Negative or zero stride")
Eric Kunzee5e26762020-10-13 16:11:07 -07002812
Kevin Chengaee1fac2020-11-11 13:54:06 -08002813 if mode == ResizeMode.BILINEAR:
2814 if input_dtype == DType.INT8:
2815 if output_dtype != DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002816 ser.setExpectedFailure(True, "Invalid output data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002817 elif input_dtype == DType.INT16:
2818 if output_dtype != DType.INT48:
Kevin Cheng989cb052021-04-28 16:29:44 -07002819 ser.setExpectedFailure(true, "Invalid output data type")
Kevin Cheng77d0f762020-11-24 10:26:32 -08002820 elif input_dtype == DType.FLOAT:
2821 if output_dtype != DType.FLOAT:
Kevin Cheng989cb052021-04-28 16:29:44 -07002822 ser.setExpectedFailure(true, "Invalid output data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002823 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002824 ser.setExpectedFailure(true, "Invalid input data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002825
2826 elif mode == ResizeMode.NEAREST:
2827 if input_dtype == DType.INT8:
2828 if output_dtype != DType.INT8:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002829 ser.setExpectedFailure(True, "Invalid output data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002830 elif input_dtype == DType.INT16:
2831 if output_dtype != DType.INT16:
Kevin Cheng989cb052021-04-28 16:29:44 -07002832 ser.setExpectedFailure(true, "Invalid output data type")
Kevin Cheng77d0f762020-11-24 10:26:32 -08002833 elif input_dtype == DType.FLOAT:
2834 if output_dtype != DType.FLOAT:
Kevin Cheng989cb052021-04-28 16:29:44 -07002835 ser.setExpectedFailure(true, "Invalid output data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002836 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002837 ser.setExpectedFailure(true, "Invalid input data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002838
2839 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002840 ser.setExpectedFailure(true, "Invalid resize mode")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002841
Kevin Cheng550ccc52021-03-03 11:21:43 -08002842 return ser.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002843
2844 @staticmethod
2845 def typeConversionOp(ser, val, out_dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002846 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002847
2848 @staticmethod
2849 def transposeConv2DOp(ser, ifm, output_shape):
Kevin Cheng3a478572021-01-22 17:21:02 -08002850 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002851 out_dtype = DType.INT32
2852 elif ifm.dtype == DType.INT16:
2853 out_dtype = DType.INT48
2854 elif ifm.dtype == DType.FLOAT:
2855 out_dtype = DType.FLOAT
2856 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002857 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002858
2859 if output_shape[1] <= 0 or output_shape[2] <= 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002860 ser.setExpectedFailure(True, "Negative output shape")
Eric Kunzee5e26762020-10-13 16:11:07 -07002861
Kevin Cheng550ccc52021-03-03 11:21:43 -08002862 return ser.addOutput(output_shape, out_dtype)