#!/usr/bin/env python3

# Copyright (c) 2020-2021, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import argparse
import sys
import re
import os
import subprocess
import shlex
import json
import glob
import math
import queue
import threading
import traceback
import itertools

from enum import IntEnum, Enum, unique

# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(
    os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
)
import tosa_serializer as ts
from tosa_serializer import *
import tosa

# Convenience variables to the flatc-generated types that should be enums, but aren't
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()


class TosaQuantGen:
    """QuantizedInfo random generator helper functions.  Specify with 'qgen': in the operator definition"""

    def __init__(self):
        pass

    @staticmethod
    def needsQinfo(op, dtype):
        if dtype == DType.INT8 or dtype == DType.INT16:
            return True
        return False
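
    # A sketch of how a 'qgen' hook is assumed to be wired into an operator
    # definition elsewhere in this file (the field names shown here are
    # illustrative, not authoritative; 'operands' and 'qgen' appear in the
    # generator code, the rest is by analogy):
    #
    #   "conv2d": {
    #       "operands": (1, 2),
    #       "qgen": TosaQuantGen.qgConv,  # picks random input/weight zero points
    #       ...
    #   }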

    @staticmethod
    def qgUnary(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        if TosaQuantGen.needsQinfo(op, dtype):
            qinfo.UnaryQuantInfo(testGen.randInt(), testGen.randInt())
        else:
            qinfo.UnaryQuantInfo(0, 0)
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        if TosaQuantGen.needsQinfo(op, dtype):
            qinfo.ConvQuantInfo(testGen.randInt(), testGen.randInt())
        else:
            qinfo.ConvQuantInfo(0, 0)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        if TosaQuantGen.needsQinfo(op, dtype):
            qinfo.MatMulQuantInfo(testGen.randInt(), testGen.randInt())
        else:
            qinfo.MatMulQuantInfo(0, 0)
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        if TosaQuantGen.needsQinfo(op, dtype):
            qinfo.PadQuantInfo(testGen.randInt())
        else:
            qinfo.PadQuantInfo(0)
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(scaleFp, scaleBits, m, multiplier, shift))

        assert multiplier <= (1 << scaleBits)
        assert shift >= 0 and shift <= 63

        return multiplier, shift


class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.  The actual random data is generated separately for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank):
        pl, const = opName["operands"]

        assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If this is the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list
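
    # For example (a sketch): with rank 3, a base shape of [4, 5, 6] and
    # bcast_idx == 1, this might return [[4, 5, 6], [4, 1, 6]], exercising
    # broadcast on one dimension of the chosen operand.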

    @staticmethod
    def tgConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        filter_oc = testGen.makeShape(1)[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        b_oc = testGen.makeShape(1)[0]
        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])

        return [a_shape, b_shape]
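
    # Note (a sketch): MATMUL operands here are batched, with a shaped
    # [N, H, C] and b shaped [N, C, W], so the operator under test is
    # expected to produce an [N, H, W] result (the actual output shape is
    # computed by OutputShaper when the test is built).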


class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for operators that take
    attributes or other parameters.  The return value is a list of (descriptive_name, [arglist])
    tuples where the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function."""

    def __init__(self):
        pass

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype):
        """Build the axis argument for operators that take a single axis"""
        axes = []

        shape = shapeList[0]

        for a in range(0, len(shape)):
            axes.append(("axis_{}".format(a), [a]))
        return axes

    @staticmethod
    def agConv2D(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        maxStride = testGen.args.max_conv_stride
        maxPadding = testGen.args.max_conv_padding + 1
        maxDilation = testGen.args.max_conv_dilation

        # Strides, padding, dilations
        for stride in range(0, maxStride ** 2):
            for padding in range(0, (maxPadding) ** 4):
                for dilation in range(0, maxDilation ** 2):

                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    p = [
                        (padding // (maxPadding * 4)) % maxPadding,
                        (padding // (maxPadding * 2)) % maxPadding,
                        (padding // (maxPadding * 1)) % maxPadding,
                        padding % maxPadding,
                    ]
                    d = [dilation // maxDilation + 1, dilation % maxDilation + 1]

                    # 4 padding parameters for regular conv2d
                    arg_list.append(
                        (
                            "st{}{}_pad{}{}{}{}_dilat{}{}".format(
                                s[0], s[1], p[0], p[1], p[2], p[3], d[0], d[1]
                            ),
                            [s, p, d],
                        )
                    )
        return arg_list
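
    # Sketch of the index flattening above: with maxStride == 2, the stride
    # counter runs over 0..3 and s decodes to [1, 1], [1, 2], [2, 1], [2, 2]
    # in turn; the padding and dilation counters are unpacked the same way.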

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        maxStride = testGen.args.max_conv_stride
        maxPadding = testGen.args.max_conv_padding + 1
        maxDilation = testGen.args.max_conv_dilation

        # Strides, padding, dilations
        for stride in range(0, maxStride ** 2):
            for out_padding in range(0, (maxPadding) ** 2):
                for dilation in range(0, maxDilation ** 2):

                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    p = [
                        (out_padding // (maxPadding * 1)) % maxPadding,
                        out_padding % maxPadding,
                    ]
                    d = [dilation // maxDilation + 1, dilation % maxDilation + 1]

                    oh = (
                        ifm_shape[1]
                        - filter_shape[1]
                        - (filter_shape[1] - 1) * (d[0] - 1)
                        + 2 * p[0]
                    ) // s[0] + 1

                    ow = (
                        ifm_shape[2]
                        - filter_shape[2]
                        - (filter_shape[2] - 1) * (d[1] - 1)
                        + 2 * p[1]
                    ) // s[1] + 1

                    # Output shape
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]

                    arg_list.append(
                        (
                            "st{}{}_outpad{}{}_dilat{}{}_os{}x{}x{}x{}".format(
                                s[0],
                                s[1],
                                p[0],
                                p[1],
                                d[0],
                                d[1],
                                os[0],
                                os[1],
                                os[2],
                                os[3],
                            ),
                            [s, p, d, os],
                        )
                    )

        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype):
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of 0/1 padding on each side of each dimension
        # This process might need some revision for >1 padding, but use rank**2 as a bitmask
        # for now
        for v in range(rank ** 2):

            # Create a flat padding array
            paddings = np.zeros((rank * 2), dtype=np.int32)

            # Fill in the 1's
            for r in range(rank * 2):
                if (v >> r) & 1:
                    paddings[r] = 1

            # Reshape back to a 2D array
            paddings = paddings.reshape((rank, 2))

            arg_list.append(("pad{0:b}".format(v), [paddings]))

        return arg_list
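
    # Sketch of the bitmask decoding above: for rank 2, v == 0b11 yields the
    # flat padding [1, 1, 0, 0], which reshapes to [[1, 1], [0, 0]] -- one
    # element of padding on both sides of the first dimension. Note that
    # range(rank ** 2) enumerates only a subset of the 2 ** (rank * 2)
    # possible masks, as the comment above anticipates.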

    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype):
        arg_list = []

        shape = shapeList[0]
        assert len(shape) == 4

        maxStride = testGen.args.max_pooling_stride
        maxKernel = testGen.args.max_pooling_kernel
        maxPadding = testGen.args.max_pooling_padding + 1

        for kernel in range(0, maxKernel ** 2):
            for stride in range(0, maxStride ** 2):
                for padding in range(0, maxPadding ** 4):
                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    k = [(kernel // maxKernel) + 2, (kernel % maxKernel) + 2]
                    p = [
                        (padding // (maxPadding * 4)) % maxPadding,
                        (padding // (maxPadding * 2)) % maxPadding,
                        (padding // (maxPadding * 1)) % maxPadding,
                        padding % maxPadding,
                    ]

                    arg_list.append(
                        (
                            "st{}{}_kern{}{}_pad{}{}{}{}".format(
                                s[0], s[1], k[0], k[1], p[0], p[1], p[2], p[3]
                            ),
                            [k, s, p],
                        )
                    )
        return arg_list

    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype):
        arg_list = []

        # Enumerate the output types here
        if inDtype == DType.INT8:
            dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT16:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT32:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FLOAT:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))

        return arg_list

    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype):
        arg_list = []

        # Enumerate the output types here
        for dtype in [DType.INT8, DType.INT16, DType.INT32]:
            for scale32 in [False, True]:
                for double_round in [False, True]:
                    for per_channel in [False, True]:

                        if inDtype == DType.INT48 and scale32:
                            # Illegal condition.  Must be scale32=False
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[dtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [dtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

    @staticmethod
    def agMul(testGen, opName, shapeList, dtype):
        arg_list = []

        if dtype is DType.INT32:
            for p in range(testGen.args.num_rand_permutations):

                shift = testGen.randInt(0, 32)

                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("shift0", [0]))

        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype):
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    # Helper function for reshape.  Gets some factors of a larger number.
    @staticmethod
    def getFactors(val, start=1):
        factors = []

        for i in range(start, int(np.sqrt(val))):
            if (val % i) == 0:
                factors.append(i)

        return factors
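
    # For example (a sketch): getFactors(24) walks i over range(1, 4) and
    # returns [1, 2, 3]; factors at or above sqrt(val) are deliberately left
    # out, since agReshape divides the remaining element count back down by
    # whichever factor it draws.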

    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype):
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast.  Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 6)
            newShape = []
            if len(factors) < newRank:
                continue

            remainingElements = totalElements
            shuffledFactors = testGen.rng.permutation(factors)
            for i in range(newRank - 1):
                # pick rank-1 factors
                newShape.append(shuffledFactors[0])
                remainingElements = remainingElements // shuffledFactors[0]
                shuffledFactors = testGen.rng.permutation(
                    TosaArgGen.getFactors(remainingElements)
                )
            newShape.append(remainingElements)

            # Toss in a -1 sometimes
            minusOne = testGen.randInt(0, newRank * 4)
            if minusOne < newRank:
                newShape[minusOne] = -1

            arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list

    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]

        # Get all permutations
        permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Get random permutation generator that uses all permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create list of required amount of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list

    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            begin = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    begin.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    begin.append(0)
                    size.append(1)

            if valid:
                arg_list.append(("perm{}".format(p), [begin, size]))
        return arg_list

    @staticmethod
    def agTile(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random, but small multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                multiples.append(testGen.randInt(1, 4))

            arg_list.append(("perm{}".format(p), [multiples]))

        return arg_list

    @staticmethod
    def agResize(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]

        for m in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations.  Pick legal output types
            if m == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif m == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):

                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    output_dims = [testGen.randInt(1), testGen.randInt(1)]
                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]
                        arg_list.append(
                            (
                                "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                    m,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride_fp[0],
                                    stride_fp[1],
                                    offset_fp[0],
                                    offset_fp[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )
                    else:
                        shift = 11
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))

                        while (
                            stride_y >= 32768
                            or stride_x >= 32768
                            or offset_y >= 32768
                            or offset_x >= 32768
                            or offset_y < -32768
                            or offset_x < -32768
                        ):
                            shift = shift - 1
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                        arg_list.append(
                            (
                                "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                    m,
                                    shift,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride[0],
                                    stride[1],
                                    offset[0],
                                    offset[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )

        return arg_list
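
    # Sketch of the fixed-point encoding above: resizing H from 10 to 5 gives
    # fp_stride_y = 2.0; with the initial shift of 11 the encoded stride is
    # round(2.0 * 2**11) = 4096, and shift is only walked down when an encoded
    # stride or offset falls outside the signed 16-bit range.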

    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype):
        # While loop: 0 iterations, 1, more than 1
        arg_list = []

        for iter in [0, 1, 4]:
            arg_list.append(("iter{}".format(iter), [iter]))

        return arg_list


class TosaTestGen:
    def __init__(self, args):
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        self.ser = None
        self.rng = np.random.default_rng(self.random_seed)
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None

    def createSerializer(self, opName, testPath):
        self.testPath = os.path.join(opName, testPath)

        fullPath = os.path.join(self.basePath, self.testPath)
        os.makedirs(fullPath, exist_ok=True)
        self.ser = ts.TosaSerializer(fullPath)

    def getSerializer(self):
        return self.ser

    def serialize(self, testName):
        with open(
            os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
        ) as fd:
            fd.write(self.ser.serialize())

        with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
            fd.write(self.ser.writeJson("{}.tosa".format(testName)))

    def getRandTensor(self, shape, dtype):
        RAND_SHIFT_FACTOR = 0.5
        RAND_SCALE_FACTOR = 4.0

        if dtype == DType.BOOL:
            return np.bool_(self.rng.choice(a=[False, True], size=shape))
        elif dtype == DType.INT4:
            return np.int32(self.rng.integers(low=-7, high=8, size=shape))
        elif dtype == DType.INT8:
            return np.int32(self.rng.integers(low=-127, high=128, size=shape))
        elif dtype == DType.INT16:
            return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
        elif dtype == DType.INT32:
            return np.int32(
                self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
            )
        elif dtype == DType.INT48:
            return np.int64(
                self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
            )
        elif dtype == DType.FLOAT:
            # Shift then scale so values land in [-2.0, 2.0)
            return np.float32(
                (self.rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR
            )
        else:
            raise Exception("Unrecognized Dtype: {}".format(dtype))

    def buildPlaceholderTensors(self, shape_list, dtype_list):
        placeholders = []

        assert len(shape_list) == len(dtype_list)

        for idx, shape in enumerate(shape_list):
            arr = self.getRandTensor(shape, dtype_list[idx])
            placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))

        return placeholders

    def buildConstTensors(self, shape_list, dtype_list):
        consts = []

        assert len(shape_list) == len(dtype_list)

        for idx, shape in enumerate(shape_list):
            arr = self.getRandTensor(shape, dtype_list[idx])
            consts.append(self.ser.addConst(shape, dtype_list[idx], arr))

        return consts

    def makeShape(self, rank):
        if self.targetted_shape:
            return np.int32(self.targetted_shape)
        return np.int32(
            self.rng.integers(
                low=self.args.tensor_shape_range[0],
                high=self.args.tensor_shape_range[1],
                size=rank,
            )
        )

    def setTargetShape(self, shape):
        self.targetted_shape = shape

    def randInt(self, low=0, high=256):
        return np.int32(self.rng.integers(low=low, high=high, size=1))[0]

    def getRandNumberDType(self, dtype):
        if dtype == DType.FLOAT:
            return self.rng.random()
        elif dtype == DType.BOOL:
            return self.rng.choice([False, True])
        elif dtype == DType.INT4:
            low, high = (-7, 8)
        elif dtype == DType.INT8:
            low, high = (-127, 128)
        elif dtype == DType.INT16:
            low, high = (-32768, 32768)
        elif dtype == DType.INT32:
            low, high = (-(1 << 31), (1 << 31))
        elif dtype == DType.INT48:
            low, high = (-(1 << 47), (1 << 47))
            # Special size
            return np.int64(self.rng.integers(low, high, size=1))[0]
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

        return np.int32(self.rng.integers(low, high, size=1))[0]

    def shapeStr(self, shape):

        sStr = []
        # Convert to strings
        for i in shape:
            sStr.append(str(i))

        return "x".join(sStr)

    def typeStr(self, t):
        if isinstance(t, list):
            assert len(t) >= 2
            return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
        else:
            if t == DType.BOOL:
                return "b"
            elif t == DType.INT4:
                return "i4"
            elif t == DType.INT8:
                return "i8"
            elif t == DType.UINT8:
                return "u8"
            elif t == DType.INT16:
                return "i16"
            elif t == DType.INT32:
                return "i32"
            elif t == DType.INT48:
                return "i48"
            elif t == DType.FLOAT:
                return "float"
            else:
                raise Exception("Unknown dtype, cannot convert to string: {}".format(t))

    def typeWidth(self, t):
        """Get the datatype width for integer types"""
        if t == DType.INT4:
            return 4
        elif t == DType.INT8:
            return 8
        elif t == DType.UINT8:
            return 8
        elif t == DType.INT16:
            return 16
        elif t == DType.INT32:
            return 32
        elif t == DType.INT48:
            return 48
        else:
            raise Exception("Unknown dtype, cannot determine width: {}".format(t))

    # Argument generators
    # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
    # Where the string descriptor is used to generate the test name and
    # The build_fcn_arg_list is expanded and passed to the operator test
    # build function
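    #
    # For example (a sketch; the exact call plumbing lives in the test-runner
    # code later in this file): an agAxis entry ("axis_1", [1]) appends
    # "_axis_1" to the test name and is expanded so the build function is
    # invoked roughly as build_argmax(op, input_tensor, 1).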

    def build_unary(self, op, a, qinfo=None):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        self.ser.addOperator(op, [a.name], [result_tens.name], None, qinfo)
        return result_tens

    def build_binary_broadcast(self, op, a, b):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
        return result_tens

    def build_binary_nonbroadcast(self, op, a, b):
        result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
        return result_tens

    def build_arithmetic_right_shift(self, op, a, b, round):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)

        attr = ts.TosaSerializerAttribute()
        attr.ArithmeticRightShiftAttribute(round)

        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
        return result_tens

    def build_mul(self, op, a, b, shift):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)

        # Special for multiply:
        # Force the result to INT32 for INT types
        if a.dtype != DType.FLOAT:
            result_tens.setDtype(DType.INT32)

        attr = ts.TosaSerializerAttribute()
        attr.MulAttribute(shift)

        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
        return result_tens

    def build_table(self, op, a):
        # Constant size, random values
        table_arr = self.getRandTensor([513], DType.INT16)
        table_tens = self.ser.addConst(table_arr.shape, DType.INT16, table_arr)

        result_tens = OutputShaper.tableOp(self.ser, a, table_tens)
        self.ser.addOperator(op, [a.name, table_tens.name], [result_tens.name], None)

        return result_tens

    def build_select(self, op, cond, a, b):

        # Replace the cond tensor with a boolean tensor since it probably
        # has the wrong dtype
        t = self.buildPlaceholderTensors([cond.shape], [DType.BOOL])
        cond = t[0]

        result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
        self.ser.addOperator(op, [cond.name, a.name, b.name], [result_tens.name])

        return result_tens

    def build_comparison(self, op, a, b):
        result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
        return result_tens

    def build_argmax(self, op, a, axis):
        result_tens = OutputShaper.argmaxOp(self.ser, a, axis)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_pool2d(self, op, input, kernel, stride, pad, qinfo=None):
        result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)

        attr = ts.TosaSerializerAttribute()
        attr.Pool2dAttribute(kernel, stride, pad)

        self.ser.addOperator(op, [input.name], [result_tens.name], attr, qinfo)
        return result_tens

    def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
        assert len(padding) == 4
        result_tens = OutputShaper.conv2dOp(
            self.ser, ifm, filter, strides, padding, dilations
        )

        attr = ts.TosaSerializerAttribute()
        attr.Conv2dAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_transpose_conv2d(
        self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
    ):
        assert len(outpad) == 2
        result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)

        attr = ts.TosaSerializerAttribute()
        attr.TransposeConv2DAttribute(outpad, stride, dilation, output_shape)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_depthwise_conv2d(
        self, op, ifm, filter, bias, strides, padding, dilations, qinfo
    ):
        result_tens = OutputShaper.depthwiseConv2dOp(
            self.ser, ifm, filter, strides, padding, dilations
        )

        attr = ts.TosaSerializerAttribute()
        attr.Conv2dAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_fully_connected(self, op, ifm, filter, bias, qinfo):
        result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
        )
        return result_tens

    def build_matmul(self, op, a, b, qinfo):
        result_tens = OutputShaper.matmulOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], None, qinfo)
        return result_tens

    def build_reduce(self, op, a, axis):
        result_tens = OutputShaper.reduceOp(self.ser, a, axis)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_clamp(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        attr = ts.TosaSerializerAttribute()

        # Get two random ints
        v = [self.randInt(), self.randInt()]

        if a.dtype == DType.FLOAT:
            attr.ClampAttribute(0, 0, min(v), max(v))
        else:
            attr.ClampAttribute(min(v), max(v), 0, 0)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_leaky_relu(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        attr = ts.TosaSerializerAttribute()

        attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    # Needs an additional type/input
    def build_prelu(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        self.ser.addOperator(op, [a.name], [result_tens.name])
        return result_tens

    def build_relun(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        attr = ts.TosaSerializerAttribute()

        if a.dtype == DType.FLOAT:
            attr.ReluNAttribute(0, self.getRandNumberDType(a.dtype))
        else:
            attr.ReluNAttribute(self.getRandNumberDType(a.dtype), 0)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_sigmoid(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        self.ser.addOperator(op, [a.name], [result_tens.name])
        return result_tens

    def build_tanh(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        self.ser.addOperator(op, [a.name], [result_tens.name])
        return result_tens

    def build_concat(self, op, a, b, axis):
        result_tens = OutputShaper.concatOp(self.ser, a, b, axis)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
        return result_tens

    def build_pad(self, op, a, padding, qinfo):
        result_tens = OutputShaper.padOp(self.ser, a, padding)

        # Need to turn the padding array into a TOSA tensor here.
        # This is one of the few tensor operands that does not get
        # randomly generated
        padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)

        self.ser.addOperator(
            op, [a.name, padding_tens.name], [result_tens.name], None, qinfo
        )
        return result_tens

    def build_reshape(self, op, a, newShape):
        result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)

        attr = ts.TosaSerializerAttribute()
        attr.ReshapeAttribute(newShape)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_reverse(self, op, a, axis):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_transpose(self, op, a, perms):
        result_tens = OutputShaper.transposeOp(self.ser, a, perms)

        perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))

        self.ser.addOperator(op, [a.name, perms_tens.name], [result_tens.name])
        return result_tens

    def build_slice(self, op, a, begin, size):
        result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)

        attr = ts.TosaSerializerAttribute()
        attr.SliceAttribute(begin, size)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_tile(self, op, a, multiples):
        result_tens = OutputShaper.tileOp(self.ser, a, multiples)

        attr = ts.TosaSerializerAttribute()
        attr.TileAttribute(multiples)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_gather(self, op, values):

        # Create a new indices tensor
        # here with data that doesn't exceed the dimensions of the values tensor

        K = values.shape[1]  # K
        W = self.randInt(
            self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
        )  # W
        indicies_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values.shape[0], W])
        )  # (N, W)
        indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

        result_tens = OutputShaper.gatherOp(self.ser, values, indicies)

        self.ser.addOperator(op, [values.name, indicies.name], [result_tens.name])

        return result_tens

    def build_scatter(self, op, values_in, input):

        # Create a new indices tensor
        # here with data that doesn't exceed the dimensions of the values_in tensor

        K = values_in.shape[1]  # K
        W = input.shape[1]  # W
        indicies_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
        )  # (N, W)
        indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

        result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)

        self.ser.addOperator(
            op, [values_in.name, indicies.name, input.name], [result_tens.name]
        )

        return result_tens

    def build_resize(
        self,
        op,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
    ):
        result_tens = OutputShaper.resizeOp(
            self.ser,
            input,
            mode,
            stride,
            offset,
            shift,
            stride_fp,
            offset_fp,
            output_dims,
            input_dtype,
            output_dtype,
        )

        attr = ts.TosaSerializerAttribute()

        attr.ResizeAttribute(
            output_dims, stride, offset, shift, stride_fp, offset_fp, mode
        )

        self.ser.addOperator(op, [input.name], [result_tens.name], attr)
        return result_tens

    def build_identityn(self, op, val, val2):

        result_tens = OutputShaper.unaryOp(self.ser, val)
        result_tens2 = OutputShaper.unaryOp(self.ser, val2)
        self.ser.addOperator(
            op, [val.name, val2.name], [result_tens.name, result_tens2.name]
        )
        return result_tens

    def build_placeholder(self, op, val):
        # Add an identity op to avoid warning in the reference model
        return self.build_unary(Op.IDENTITY, val)

    # Type Conversion
    def build_cast(self, op, val, out_dtype):
        result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
        self.ser.addOperator(op, [val.name], [result_tens.name])
        return result_tens

    def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
        result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)

        if per_channel:
            nc = val.shape[-1]
        else:
            nc = 1

        in_type_width = self.typeWidth(val.dtype)
        out_type_width = self.typeWidth(out_dtype)

        if val.dtype == DType.INT8:
            input_zp = self.randInt(-128, 127)
            in_type_width = in_type_width + 1
        else:
            input_zp = 0

        if out_dtype == DType.INT8:
            output_zp = self.randInt(-128, 127)
            out_type_width = out_type_width + 1
        else:
            output_zp = 0

        # Calculate scale based on:
        # scale = a * (2^output_width) / (2^input_width)

        a = np.float32(self.rng.random(size=[nc]))
        scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

        if scale32:
            # Cap the scaling at 2^31 - 1 for scale32
            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
        else:
            # Cap the scaling at 2^15 - 1 for scale16
            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

        # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))

        multiplier_arr = np.int32(np.zeros(shape=[nc]))
        shift_arr = np.int32(np.zeros(shape=[nc]))

        for i in range(nc):
            multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
                scale_arr[i], scale32
            )
            if shift_arr[i] < 2 or shift_arr[i] > 62:
                self.ser.setExpectedFailure(True, "OpRescale: invalid shift value")

        # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))

        attr = ts.TosaSerializerAttribute()
        attr.RescaleAttribute(
            input_zp,
            output_zp,
            multiplier_arr,
            shift_arr,
            scale32,
            double_round,
            per_channel,
        )

        self.ser.addOperator(op, [val.name], [result_tens.name], attr)
        return result_tens
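
    # Worked sketch of the scale derivation above: rescaling an INT16 input to
    # an INT8 output (with a random output_zp) uses in_type_width = 16 and
    # out_type_width = 8 + 1 = 9, so scale = a * 2**9 / 2**16 = a / 128, which
    # is then clipped and folded into per-channel multiplier/shift pairs.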
1461
1462 def build_cond_if_const(self, op, then_tens, else_tens, cond):
1463 # For cond_if with constants, we're supplied with then/else tensors that we ignore
1464 # (except for the generated shap) and the condition. Build Then/Else blocks
1465 # and fill them with const nodes for the body.
1466
1467 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001468 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07001469
1470 # Make then/else tensors
1471 out_shape = then_tens.shape
1472 then_arr = np.int32(self.rng.integers(0, 255, size=out_shape))
1473 else_arr = np.int32(self.rng.integers(0, 255, size=out_shape))
1474
1475 # And the result tensor based on any of the outputs
Kevin Cheng550ccc52021-03-03 11:21:43 -08001476 result_tens = self.ser.addOutput(out_shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07001477
1478 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001479 then_block = "THEN_BLOCK"
1480 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001481 attr = ts.TosaSerializerAttribute()
1482 attr.CondIfAttribute(then_block, else_block)
1483
1484 # Finally, build the op and the two blocks
1485 self.ser.addOperator(op, [cond_tens.name], [result_tens.name], attr)
1486
1487 self.ser.startBasicBlock(then_block)
1488 # Build the actual then/else tensors inside their blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001489 then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001490 self.ser.addOutputTensor(then_tens)
1491
1492 self.ser.startBasicBlock(else_block)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001493 else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001494 self.ser.addOutputTensor(else_tens)
1495
1496 return result_tens
1497
1498 def build_cond_if_binary(self, op, a, b, cond):
        # For cond_if with a binary op in the then/else blocks, take a and b and
        # either add or subtract them depending on the condition
1501
1502 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001503 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07001504
Kevin Cheng550ccc52021-03-03 11:21:43 -08001505 result_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001506 self.ser.currBasicBlock.addOutput(result_tens.name)
1507
1508 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001509 then_block = "THEN_BLOCK"
1510 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001511 attr = ts.TosaSerializerAttribute()
1512 attr.CondIfAttribute(then_block, else_block)
1513
1514 # Finally, build the op and the two blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001515 self.ser.addOperator(
1516 op, [cond_tens.name, a.name, b.name], [result_tens.name], attr
1517 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001518
1519 self.ser.startBasicBlock(then_block)
1520 self.ser.addInputTensor(a)
1521 self.ser.addInputTensor(b)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001522 then_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001523 self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])
1524
1525 self.ser.startBasicBlock(else_block)
1526 self.ser.addInputTensor(a)
1527 self.ser.addInputTensor(b)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001528 else_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001529 self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])
1530
1531 return result_tens
1532
1533 def build_while_loop(self, op, a, iter_val):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001534 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07001535
Kevin Cheng550ccc52021-03-03 11:21:43 -08001536 cond_block = "COND_BLOCK"
1537 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001538
1539 attr = ts.TosaSerializerAttribute()
1540 attr.WhileLoopAttribute(cond_block, body_block)
1541
1542 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001543 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001544 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001545 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07001546
1547 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08001548 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
1549 a_out = self.ser.addIntermediate(a.shape, a.dtype)
1550 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001551
1552 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08001553 self.ser.addOperator(
1554 op,
1555 [iter.name, a.name, acc.name],
1556 [iter_out.name, a_out.name, acc_out.name],
1557 attr,
1558 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001559
        # COND block (inputs: iter, a, acc; output: cond_tens)
1561 self.ser.startBasicBlock(cond_block)
1562 self.ser.addInputTensor(iter)
1563 self.ser.addInputTensor(a)
1564 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001565 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
1566 cond_tens = self.ser.addOutput([], DType.BOOL)
1567 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001568
        # BODY block (inputs: iter, a, acc; outputs: iter, a, acc)
1570 # Note that local intermediate tensors need to be declared here for the outputs
1571 self.ser.startBasicBlock(body_block)
1572 self.ser.addInputTensor(iter)
1573 self.ser.addInputTensor(a)
1574 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001575 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
1576 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
1577 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001578 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
1579 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
1580 self.ser.addOutputTensor(iter_body_out)
1581 self.ser.addOutputTensor(a)
1582 self.ser.addOutputTensor(acc_body_out)
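        # Net effect: the body runs once per decrement of iter, adding a to acc
        # each pass, so for a non-negative iter_val the expected result is
        # acc_out == a * iter_val (elementwise).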
1583
1584 return acc_out
1585
Kevin Cheng550ccc52021-03-03 11:21:43 -08001586 def genOpTestList(
1587 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None
1588 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07001589
1590 try:
1591 op = self.TOSA_OP_LIST[opName]
1592 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001593 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07001594
1595 # Initialize a new random number generator
1596 self.rng = np.random.default_rng(self.random_seed)
1597
Kevin Cheng550ccc52021-03-03 11:21:43 -08001598 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001599
1600 # Generate the lists of arguments
Kevin Cheng550ccc52021-03-03 11:21:43 -08001601 rmin, rmax = op["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001602
        # The test list consists of tuples of:
        # (opName, testNameStr, dtype, shapeList, argumentsList)
1605 testList = []
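        # e.g. a single (hypothetical) entry, with shape/type strings depending
        # on shapeStr()/typeStr():
        # ("add", "add_4x4_i32", DType.INT32, [[4, 4], [4, 4]], [])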
1606
1607 if not shapeFilter:
1608 shapeFilter = [None]
1609
1610 for r in range(rmin, rmax + 1):
1611
1612 # Filter out the rank?
1613 if rankFilter is not None and r not in rankFilter:
1614 continue
1615
Kevin Cheng550ccc52021-03-03 11:21:43 -08001616 for t in op["types"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07001617
1618 # Filter tests based on dtype?
1619 if dtypeFilter is not None:
1620 if t not in dtypeFilter:
1621 continue
1622
1623 # Create the placeholder and const tensors
1624 for shape in shapeFilter:
1625 # A None shape chooses a random shape of a given rank
1626
1627 # Filter out by rank
1628 if shape is not None and len(shape) != r:
1629 continue
1630
1631 self.setTargetShape(shape)
1632 shapeList = tgen_fcn(self, op, r)
1633
1634 shapeStr = self.shapeStr(shapeList[0])
1635 typeStr = self.typeStr(t)
1636
                    # Argument lists consist of tuples of the (str, []) string
                    # representation and the build function argument list
1638 argList = []
1639 if agen_fcn:
1640 argList = agen_fcn(self, opName, shapeList, t)
1641 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001642 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07001643
1644 for argStr, args in argList:
1645 if argStr:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001646 testStr = "{}_{}_{}_{}".format(
1647 opName, shapeStr, typeStr, argStr
1648 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001649 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001650 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001651
1652 testList.append((opName, testStr, t, shapeList, args))
1653
1654 return testList
1655
Kevin Cheng989cb052021-04-28 16:29:44 -07001656 def serializeTest(self, opName, testStr, dtype_or_dtypeList, shapeList, testArgs):
Eric Kunzee5e26762020-10-13 16:11:07 -07001657 try:
1658 op = self.TOSA_OP_LIST[opName]
1659 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001660 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07001661
1662 # Create a serializer
1663 self.createSerializer(opName, testStr)
1664
Kevin Cheng550ccc52021-03-03 11:21:43 -08001665 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
1666 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07001667 num_operands = pCount + cCount
1668
1669 if isinstance(dtype_or_dtypeList, list):
1670 dtypeList = dtype_or_dtypeList
1671 else:
1672 dtypeList = [dtype_or_dtypeList] * (num_operands)
1673
1674 assert (
1675 len(shapeList) == num_operands
1676 ), "shapeList length {} must match number of operands {}".format(
1677 len(shapeList), num_operands
1678 )
1679 assert (
1680 len(dtypeList) == num_operands
1681 ), "dtypeList length {} must match number of operands {}".format(
1682 len(dtypeList), num_operands
1683 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001684
1685 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001686 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001687 except KeyError:
1688 qgen = None
1689
1690 # Build the random tensor operands and the test
1691 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08001692
        # If test is ArithmeticRightShift, force operand[1] values into [0, num_bits)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001694 if op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
1695 assert (
1696 pCount == 2 and cCount == 0
1697 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
Kevin Chengaee1fac2020-11-11 13:54:06 -08001698
1699 placeholders = []
1700 for idx, shape in enumerate(shapeList[:]):
1701 if idx == 1:
Kevin Cheng989cb052021-04-28 16:29:44 -07001702 if dtypeList[idx] == DType.INT8:
Kevin Chengaee1fac2020-11-11 13:54:06 -08001703 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07001704 elif dtypeList[idx] == DType.INT16:
Kevin Chengaee1fac2020-11-11 13:54:06 -08001705 arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07001706 elif dtypeList[idx] == DType.INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -08001707 arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
1708 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001709 raise Exception("OpArithmeticRightShift: invalid input dtype")
Kevin Chengaee1fac2020-11-11 13:54:06 -08001710 else:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001711 arr = self.getRandTensor(shape, dtypeList[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -07001712 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
Kevin Chengaee1fac2020-11-11 13:54:06 -08001713
1714 tens.extend(placeholders)
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001715 elif op["op"] == Op.DIV:
1716 assert (
1717 pCount == 2 and cCount == 0
1718 ), "Op.Div must have 2 placeholders, 0 consts"
1719
1720 placeholders = []
1721
1722 # Two invalid cases for Op.DIV:
1723 # 1. divisor == 0
Kevin Cheng47315e12021-05-13 17:41:28 -07001724 # 2. dividend == -(1<<31) and divisor == -1
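            # (case 2 would produce 1 << 31, which overflows int32)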
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001725 while True:
1726 dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
1727 divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])
1728
1729 if (divisor_arr == 0).any():
1730 continue
1731
Kevin Cheng47315e12021-05-13 17:41:28 -07001732 if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001733 continue
1734
1735 break
1736
1737 placeholders.append(
1738 self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1739 )
1740 placeholders.append(
1741 self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1742 )
1743
1744 tens.extend(placeholders)
1745 elif op["op"] == Op.MUL:
1746 assert (
1747 pCount == 2 and cCount == 0
1748 ), "Op.MUL must have 2 placeholders, 0 consts"
1749
1750 if dtypeList[0] == DType.FLOAT:
1751 tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
1752 else:
1753 placeholders = []
1754
1755 # Make sure multiply result in int32 range
1756 shift = testArgs[0]
1757 if dtypeList[0] == DType.INT8:
1758 num_bits = 8
1759 elif dtypeList[0] == DType.INT16:
1760 num_bits = 16
1761 elif dtypeList[0] == DType.INT32:
1762 num_bits = 32
1763 else:
1764 raise Exception("OpMul: invalid input dtype")
1765
1766 for idx, shape in enumerate(shapeList[:]):
1767 low = -(2 ** (num_bits - 1))
1768 high = (2 ** (num_bits - 1)) - 1
1769
1770 a_arr = np.int32(
1771 self.rng.integers(low=low, high=high, size=shapeList[0])
1772 )
1773 b_arr = np.int32(
1774 self.rng.integers(low=low, high=high, size=shapeList[1])
1775 )
1776
1777 i = 0
1778 while True:
1779
1780 a_arr_64 = a_arr.astype(np.int64)
1781 b_arr_64 = b_arr.astype(np.int64)
1782
1783 if shift > 0:
1784 rounding = 1 << (shift - 1)
1785 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
1786 else:
1787 result_arr = a_arr_64 * b_arr_64
1788
1789 if (result_arr > -(2 ** 31)).all() and (
1790 result_arr <= ((2 ** 31) - 1)
1791 ).all():
1792 break
1793
1794 i = i + 1
1795 a_arr = a_arr // 2
1796 b_arr = b_arr // 2
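                        # Halving both operands shrinks the worst-case product
                        # by ~4x per iteration, so this loop terminates quickly.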
1797
1798 placeholders.append(
1799 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1800 )
1801 placeholders.append(
1802 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1803 )
1804
1805 tens.extend(placeholders)
Kevin Chengaee1fac2020-11-11 13:54:06 -08001806 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07001807 tens.extend(
1808 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
1809 )
1810 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001811
1812 if qgen is not None:
Kevin Cheng989cb052021-04-28 16:29:44 -07001813 qinfo = qgen(self, op, dtypeList[0])
Eric Kunzee5e26762020-10-13 16:11:07 -07001814 else:
1815 qinfo = None
1816
1817 try:
1818 if qinfo is not None:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001819 resultName = build_fcn(self, op["op"], *tens, *testArgs, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07001820 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001821 resultName = build_fcn(self, op["op"], *tens, *testArgs)
Eric Kunzee5e26762020-10-13 16:11:07 -07001822 except TypeError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001823 print(
1824 "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
1825 build_fcn, tens, testArgs
1826 )
1827 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001828 raise e
1829
1830 # Save the serialized test
Kevin Cheng550ccc52021-03-03 11:21:43 -08001831 self.serialize("test")
Eric Kunzee5e26762020-10-13 16:11:07 -07001832
1833 def createDynamicOpLists(self):
1834
1835 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng550ccc52021-03-03 11:21:43 -08001836 KERNELS = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
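        # Each kernel size k yields ops named like "conv2d_3x3",
        # "depthwise_conv2d_3x3" and "transpose_conv2d_3x3" below.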
Eric Kunzee5e26762020-10-13 16:11:07 -07001837
1838 for k in KERNELS:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001839 testName = "conv2d_{}x{}".format(k[0], k[1])
1840 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
1841 self.TOSA_OP_LIST[testName]["filter"] = k
1842 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001843
Kevin Cheng550ccc52021-03-03 11:21:43 -08001844 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
1845 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
1846 "depthwise_conv2d_TEMPLATE"
1847 ].copy()
1848 self.TOSA_OP_LIST[testName]["filter"] = k
1849 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001850
Kevin Cheng550ccc52021-03-03 11:21:43 -08001851 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
1852 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
1853 "transpose_conv2d_TEMPLATE"
1854 ].copy()
1855 self.TOSA_OP_LIST[testName]["filter"] = k
1856 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001857
        # Delete any templates after having created the dynamic ops.
        # This is a two-pass operation because deleting keys from a
        # dictionary while iterating over it is an error in Python 3.
1861 keyList = []
1862 for k in self.TOSA_OP_LIST:
1863 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001864 if self.TOSA_OP_LIST[k]["template"] == True:
Eric Kunzee5e26762020-10-13 16:11:07 -07001865 keyList.append(k)
1866 continue
1867 except KeyError:
1868 pass
1869
1870 for k in keyList:
1871 del self.TOSA_OP_LIST[k]
1872
1873 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001874 """Fill in default fields for ops if they aren't already specified.
1875 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07001876 for op in self.TOSA_OP_LIST:
1877
1878 # Required fields
1879 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001880 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001881 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001882 raise Exception(
1883 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
1884 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001885
1886 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001887 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001888 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001889 raise Exception(
1890 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
1891 op
1892 )
1893 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001894
1895 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001896 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001897 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001898 raise Exception(
1899 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
1900 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001901
1902 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001903 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001904 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001905 raise Exception(
1906 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
1907 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001908
1909 # Put in default rank range, if missing
1910 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001911 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001912 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001913 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07001914
1915 # Tensor operator list
1916 # 'op': op name
1917 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08001918 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
1919 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07001920 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
1921 # 'types': array of datatypes to be tested
Kevin Cheng550ccc52021-03-03 11:21:43 -08001922 TYPE_FP = [DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07001923
Kevin Cheng550ccc52021-03-03 11:21:43 -08001924 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
1925 TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07001926
Kevin Cheng550ccc52021-03-03 11:21:43 -08001927 TYPE_BOOL = [DType.BOOL]
1928 TYPE_FI32 = [DType.FLOAT, DType.INT32]
1929 TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
1930 TYPE_FI16 = [DType.FLOAT, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07001931
Kevin Cheng550ccc52021-03-03 11:21:43 -08001932 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07001933
Kevin Cheng989cb052021-04-28 16:29:44 -07001934 TYPE_CONV2D = [
1935 [DType.INT8, DType.INT8, DType.INT32],
1936 [DType.INT16, DType.INT8, DType.INT48],
1937 DType.FLOAT,
1938 ]
1939
Eric Kunzee5e26762020-10-13 16:11:07 -07001940 DEFAULT_RANK_RANGE = (1, 4)
1941
1942 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08001943 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08001944 "argmax": {
1945 "op": Op.ARGMAX,
1946 "operands": (1, 0),
1947 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
1948 "types": TYPE_NARROW_INT_FP,
1949 },
Jared Smolens573ecd42021-03-04 15:24:10 -08001950 "avg_pool2d": {
1951 "op": Op.AVG_POOL2D,
1952 "operands": (1, 0),
1953 "rank": (4, 4),
1954 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
1955 "qgen": TosaQuantGen.qgUnary,
1956 "types": TYPE_NARROW_INT_FP,
1957 },
Eric Kunzee5e26762020-10-13 16:11:07 -07001958 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08001959 "conv2d_TEMPLATE": {
1960 "op": Op.CONV2D,
1961 "operands": (1, 2),
1962 "rank": (4, 4),
1963 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D),
1964 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07001965 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001966 "template": True,
1967 },
Jared Smolens573ecd42021-03-04 15:24:10 -08001968 # Conv3d TBD
Eric Kunzee5e26762020-10-13 16:11:07 -07001969 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08001970 "depthwise_conv2d_TEMPLATE": {
1971 "op": Op.DEPTHWISE_CONV2D,
1972 "operands": (1, 2),
1973 "filter": [1, 1],
1974 "rank": (4, 4),
1975 "build_fcn": (
1976 build_depthwise_conv2d,
1977 TosaTensorGen.tgDepthwiseConv2D,
1978 TosaArgGen.agConv2D,
1979 ),
1980 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07001981 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001982 "template": True,
1983 },
Jared Smolens573ecd42021-03-04 15:24:10 -08001984 "fully_connected": {
1985 "op": Op.FULLY_CONNECTED,
1986 "operands": (1, 2),
1987 "rank": (2, 2),
1988 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
1989 "qgen": TosaQuantGen.qgConv,
1990 "types": TYPE_CONV2D,
1991 },
Jared Smolens573ecd42021-03-04 15:24:10 -08001992 "matmul": {
1993 "op": Op.MATMUL,
1994 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07001995 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08001996 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
1997 "qgen": TosaQuantGen.qgMatmul,
1998 "types": TYPE_NARROW_INT_FP,
1999 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002000 "max_pool2d": {
2001 "op": Op.MAX_POOL2D,
2002 "operands": (1, 0),
2003 "rank": (4, 4),
2004 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
2005 "types": TYPE_NARROW_INT_FP,
2006 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002007 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002008 "transpose_conv2d_TEMPLATE": {
2009 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002010 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002011 "rank": (4, 4),
2012 "build_fcn": (
2013 build_transpose_conv2d,
2014 TosaTensorGen.tgTransposeConv2D,
2015 TosaArgGen.agTransposeConv2D,
2016 ),
2017 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07002018 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002019 "template": True,
2020 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002021 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002022 "clamp": {
2023 "op": Op.CLAMP,
2024 "operands": (1, 0),
2025 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
2026 "types": TYPE_NARROW_INT_FP,
2027 },
2028 "relun": {
2029 "op": Op.RELUN,
2030 "operands": (1, 0),
2031 "build_fcn": (build_relun, TosaTensorGen.tgBasic, None),
2032 "types": TYPE_FI32,
2033 },
2034 "sigmoid": {
2035 "op": Op.SIGMOID,
2036 "operands": (1, 0),
2037 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
2038 "types": TYPE_FP,
2039 },
2040 "tanh": {
2041 "op": Op.TANH,
2042 "operands": (1, 0),
2043 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
2044 "types": TYPE_FP,
2045 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002046 # Elementwise Binary Operators
2047 "add": {
2048 "op": Op.ADD,
2049 "operands": (2, 0),
2050 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2051 "types": TYPE_FI32,
2052 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002053 "arithmetic_right_shift": {
2054 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2055 "operands": (2, 0),
2056 "build_fcn": (
2057 build_arithmetic_right_shift,
2058 TosaTensorGen.tgBroadcastFuzz,
2059 TosaArgGen.agArithmeticRightShift,
2060 ),
2061 "types": TYPE_INT,
2062 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002063 "bitwise_and": {
2064 "op": Op.BITWISE_AND,
2065 "operands": (2, 0),
2066 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2067 "types": TYPE_INT,
2068 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002069 "bitwise_or": {
2070 "op": Op.BITWISE_OR,
2071 "operands": (2, 0),
2072 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2073 "types": TYPE_INT,
2074 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002075 "bitwise_xor": {
2076 "op": Op.BITWISE_XOR,
2077 "operands": (2, 0),
2078 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2079 "types": TYPE_INT,
2080 },
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002081 "div": {
2082 "op": Op.DIV,
2083 "operands": (2, 0),
2084 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2085 "types": [DType.INT32],
2086 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002087 "logical_and": {
2088 "op": Op.LOGICAL_AND,
2089 "operands": (2, 0),
2090 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2091 "types": TYPE_BOOL,
2092 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002093 "logical_left_shift": {
2094 "op": Op.LOGICAL_LEFT_SHIFT,
2095 "operands": (2, 0),
2096 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2097 "types": TYPE_INT,
2098 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002099 "logical_right_shift": {
2100 "op": Op.LOGICAL_RIGHT_SHIFT,
2101 "operands": (2, 0),
2102 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2103 "types": TYPE_INT,
2104 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002105 "logical_or": {
2106 "op": Op.LOGICAL_OR,
2107 "operands": (2, 0),
2108 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2109 "types": TYPE_BOOL,
2110 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002111 "logical_xor": {
2112 "op": Op.LOGICAL_XOR,
2113 "operands": (2, 0),
2114 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2115 "types": TYPE_BOOL,
2116 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002117 "maximum": {
2118 "op": Op.MAXIMUM,
2119 "operands": (2, 0),
2120 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2121 "types": TYPE_FI32,
2122 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002123 "minimum": {
2124 "op": Op.MINIMUM,
2125 "operands": (2, 0),
2126 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2127 "types": TYPE_FI32,
2128 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002129 "mul": {
2130 "op": Op.MUL,
2131 "operands": (2, 0),
2132 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
2133 "types": TYPE_INT_FP,
2134 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002135 "pow": {
2136 "op": Op.POW,
2137 "operands": (2, 0),
2138 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
2139 "types": TYPE_FP,
2140 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002141 "sub": {
2142 "op": Op.SUB,
2143 "operands": (2, 0),
2144 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2145 "types": TYPE_FI32,
2146 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002147 "table": {
2148 "op": Op.TABLE,
2149 # Use the automatic generation functions to create the input array
2150 # but create the table tensor in the build function, as it may be
2151 # a different type from the input
2152 "operands": (1, 0),
2153 "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
2154 "types": [DType.INT16],
2155 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002156 # Elementwise Unary operators
2157 "abs": {
2158 "op": Op.ABS,
2159 "operands": (1, 0),
2160 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2161 "types": TYPE_FI32,
2162 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002163 "bitwise_not": {
2164 "op": Op.BITWISE_NOT,
2165 "operands": (1, 0),
2166 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2167 "types": TYPE_INT,
2168 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002169 "ceil": {
2170 "op": Op.CEIL,
2171 "operands": (1, 0),
2172 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2173 "types": TYPE_FP,
2174 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002175 "clz": {
2176 "op": Op.CLZ,
2177 "operands": (1, 0),
2178 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2179 "types": [DType.INT32],
2180 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002181 "exp": {
2182 "op": Op.EXP,
2183 "operands": (1, 0),
2184 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2185 "types": TYPE_FP,
2186 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002187 "floor": {
2188 "op": Op.FLOOR,
2189 "operands": (1, 0),
2190 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2191 "types": TYPE_FP,
2192 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002193 "log": {
2194 "op": Op.LOG,
2195 "operands": (1, 0),
2196 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2197 "types": TYPE_FP,
2198 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002199 "logical_not": {
2200 "op": Op.LOGICAL_NOT,
2201 "operands": (1, 0),
2202 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2203 "types": TYPE_BOOL,
2204 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002205 "negate": {
2206 "op": Op.NEGATE,
2207 "operands": (1, 0),
2208 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2209 "qgen": TosaQuantGen.qgUnary,
2210 "types": TYPE_INT_FP,
2211 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002212 "reciprocal": {
2213 "op": Op.RECIPROCAL,
2214 "operands": (1, 0),
2215 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2216 "types": TYPE_FP,
2217 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002218 "rsqrt": {
2219 "op": Op.RSQRT,
2220 "operands": (1, 0),
2221 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2222 "types": TYPE_FP,
2223 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002224 # Elementwise Ternary operators
2225 "select": {
2226 "op": Op.SELECT,
2227 "operands": (3, 0),
2228 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
2229 "types": TYPE_FIB,
2230 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002231 # Comparison operators
2232 "equal": {
2233 "op": Op.EQUAL,
2234 "operands": (2, 0),
2235 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2236 "types": TYPE_FI32,
2237 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002238 "greater_equal": {
2239 "op": Op.GREATER_EQUAL,
2240 "operands": (2, 0),
2241 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2242 "types": TYPE_FI32,
2243 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002244 "greater": {
2245 "op": Op.GREATER,
2246 "operands": (2, 0),
2247 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2248 "types": TYPE_FI32,
2249 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002250 # Reduction operators
2251 "reduce_all": {
2252 "op": Op.REDUCE_ALL,
2253 "operands": (1, 0),
2254 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2255 "types": TYPE_BOOL,
2256 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002257 "reduce_any": {
2258 "op": Op.REDUCE_ANY,
2259 "operands": (1, 0),
2260 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2261 "types": TYPE_BOOL,
2262 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002263 "reduce_max": {
2264 "op": Op.REDUCE_MAX,
2265 "operands": (1, 0),
2266 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2267 "types": TYPE_INT_FP,
2268 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002269 "reduce_min": {
2270 "op": Op.REDUCE_MAX,
2271 "operands": (1, 0),
2272 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2273 "types": TYPE_INT_FP,
2274 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002275 "reduce_product": {
2276 "op": Op.REDUCE_PRODUCT,
2277 "operands": (1, 0),
2278 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2279 "types": TYPE_FP,
2280 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002281 "reduce_sum": {
2282 "op": Op.REDUCE_SUM,
2283 "operands": (1, 0),
2284 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2285 "types": TYPE_FI32,
2286 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002287 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002288 "concat": {
2289 "op": Op.CONCAT,
2290 "operands": (2, 0),
2291 "build_fcn": (build_concat, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2292 "types": TYPE_FIB,
2293 },
2294 "pad": {
2295 "op": Op.PAD,
2296 "operands": (1, 0),
2297 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
2298 "qgen": TosaQuantGen.qgPad,
2299 "types": TYPE_FIB,
2300 },
2301 "reshape": {
2302 "op": Op.RESHAPE,
2303 "operands": (1, 0),
2304 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
2305 "types": TYPE_FIB,
2306 },
2307 "reverse": {
2308 "op": Op.REVERSE,
2309 "operands": (1, 0),
2310 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2311 "types": TYPE_FIB,
2312 },
2313 "slice": {
2314 "op": Op.SLICE,
2315 "operands": (1, 0),
2316 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
2317 "types": TYPE_FIB,
2318 },
2319 "tile": {
2320 "op": Op.TILE,
2321 "operands": (1, 0),
2322 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
2323 "types": TYPE_FIB,
2324 },
2325 "transpose": {
2326 "op": Op.TRANSPOSE,
2327 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01002328 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002329 "build_fcn": (
2330 build_transpose,
2331 TosaTensorGen.tgBasic,
2332 TosaArgGen.agTranspose,
2333 ),
2334 "types": TYPE_FIB,
2335 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002336 # Data nodes
2337 "const": {
2338 "op": Op.CONST,
2339 "operands": (1, 0),
2340 "build_fcn": (build_placeholder, TosaTensorGen.tgBasic, None),
2341 "types": TYPE_FIB,
2342 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002343 "identity": {
2344 "op": Op.IDENTITY,
2345 "operands": (1, 0),
2346 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2347 "types": TYPE_FIB,
2348 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002349 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08002350 "gather": {
2351 "op": Op.GATHER,
2352 # Only specify 'values' tensor here. 'indices' is generated in op building stage
2353 "operands": (1, 0),
2354 "rank": (3, 3),
2355 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
2356 "types": TYPE_INT_FP,
2357 },
2358 "scatter": {
2359 "op": Op.SCATTER,
2360 # Only specify 'values_in' tensor here.
            # 'indices' and 'input' are generated in op building stage
2362 "operands": (2, 0),
2363 "rank": (3, 3),
2364 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
2365 "types": TYPE_INT_FP,
2366 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002367 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08002368 "resize": {
2369 "op": Op.RESIZE,
2370 "operands": (1, 0),
2371 "rank": (4, 4),
2372 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
2373 "types": [DType.INT8, DType.INT16, DType.FLOAT],
2374 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002375 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08002376 "cast": {
2377 "op": Op.CAST,
2378 "operands": (1, 0),
2379 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
2380 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
2381 },
2382 "rescale": {
2383 "op": Op.RESCALE,
2384 "operands": (1, 0),
2385 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
2386 "types": [DType.INT8, DType.INT16, DType.INT32, DType.INT48],
2387 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002388 # Custom
2389 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08002390 # Control flow operators
        # Two variants of cond_if, one that generates one of two constant tensors (no
        # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
        # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002394 "cond_if_const": {
2395 "op": Op.COND_IF,
2396 "operands": (0, 2),
2397 "build_fcn": (
2398 build_cond_if_const,
2399 TosaTensorGen.tgBasic,
2400 TosaArgGen.agCondIf,
2401 ),
2402 "types": [DType.BOOL],
2403 },
2404 "cond_if_binary": {
2405 "op": Op.COND_IF,
2406 "operands": (2, 0),
2407 "build_fcn": (
2408 build_cond_if_binary,
2409 TosaTensorGen.tgBasic,
2410 TosaArgGen.agCondIf,
2411 ),
2412 "types": TYPE_FI32,
2413 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002414 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002415 "while_loop": {
2416 "op": Op.WHILE_LOOP,
2417 "operands": (0, 1),
2418 "build_fcn": (
2419 build_while_loop,
2420 TosaTensorGen.tgBasic,
2421 TosaArgGen.agWhileLoop,
2422 ),
2423 "types": [DType.INT32],
2424 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002425 }
2426
Kevin Cheng550ccc52021-03-03 11:21:43 -08002427
Eric Kunzee5e26762020-10-13 16:11:07 -07002428class OutputShaper:
2429 # Methods in this class compute the expected output shape and datatype
2430 # for common classes of operations
2431 def __init__(self):
2432 pass
2433
2434 # These methods return arguments that can be used for
2435 # creating a new output tensor
2436 @staticmethod
2437 def binaryBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002438 assert len(a.shape) == len(b.shape)
2439 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002440
2441 shape = []
2442 for i in range(len(a.shape)):
2443 if a.shape[i] == 1:
2444 shape.append(b.shape[i])
2445 else:
2446 shape.append(a.shape[i])
2447
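        # e.g. a.shape == [1, 3] and b.shape == [2, 3] broadcast to [2, 3]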
Kevin Cheng550ccc52021-03-03 11:21:43 -08002448 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002449
2450 @staticmethod
2451 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002452 assert len(a.shape) == len(b.shape)
2453 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002454
2455 shape = []
2456 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002457 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07002458 shape.append(a.shape[i])
2459
Kevin Cheng550ccc52021-03-03 11:21:43 -08002460 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002461
2462 @staticmethod
2463 def unaryOp(ser, a):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002464 return ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002465
2466 @staticmethod
2467 def selectOp(ser, cond, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002468 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
2469 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002470
2471 shape = []
2472 for i in range(len(a.shape)):
2473 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
2474
Kevin Cheng550ccc52021-03-03 11:21:43 -08002475 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002476
2477 @staticmethod
2478 def binaryComparisonOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002479 assert len(a.shape) == len(b.shape)
2480 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002481
2482 # Do broadcast
2483 shape = []
2484 for i in range(len(a.shape)):
2485 if a.shape[i] == 1:
2486 shape.append(b.shape[i])
2487 else:
2488 shape.append(a.shape[i])
2489
2490 # Force the output type to bool
Kevin Cheng550ccc52021-03-03 11:21:43 -08002491 return ser.addOutput(shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07002492
2493 @staticmethod
2494 def reduceOp(ser, a, axis):
2495
2496 shape = a.shape.copy()
2497
2498 shape[axis] = 1
2499
Kevin Cheng550ccc52021-03-03 11:21:43 -08002500 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002501
2502 @staticmethod
2503 def argmaxOp(ser, a, axis):
2504 shape = a.shape.copy()
2505 del shape[axis]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002506 return ser.addOutput(shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07002507
2508 @staticmethod
2509 def conv2dOp(ser, ifm, filter, strides, padding, dilations):
2510
2511 # IFM: NHWC
2512 # Filter: OHWI
2513 # OFM: NHWC
2514
2515 if len(padding) == 2:
2516 # Expand padding to 4 parameters in the case of transpose_conv2d
2517 # From H,W to T,B,L,R
2518 padding = [padding[0], padding[0], padding[1], padding[1]]
2519
Kevin Cheng550ccc52021-03-03 11:21:43 -08002520 h = (
2521 ifm.shape[1]
2522 - filter.shape[1]
2523 - (filter.shape[1] - 1) * (dilations[0] - 1)
2524 + padding[0]
2525 + padding[1]
2526 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002527
Kevin Cheng550ccc52021-03-03 11:21:43 -08002528 w = (
2529 ifm.shape[2]
2530 - filter.shape[2]
2531 - (filter.shape[2] - 1) * (dilations[1] - 1)
2532 + padding[2]
2533 + padding[3]
2534 ) // strides[1] + 1
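        # Worked example: ifm H = 8, filter H = 3, dilation 1, stride 1,
        # pad (0, 0) gives h = (8 - 3 - 0 + 0 + 0) // 1 + 1 = 6.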
Eric Kunzee5e26762020-10-13 16:11:07 -07002535
2536 if h <= 0 or w <= 0:
2537 # Invalid test parameters?
2538 h = 0
2539 w = 0
Kevin Cheng550ccc52021-03-03 11:21:43 -08002540 ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002541
2542 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
2543
Kevin Cheng3a478572021-01-22 17:21:02 -08002544 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002545 out_dtype = DType.INT32
2546 elif ifm.dtype == DType.INT16:
2547 out_dtype = DType.INT48
2548 elif ifm.dtype == DType.FLOAT:
2549 out_dtype = DType.FLOAT
2550 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002551 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002552
Kevin Cheng550ccc52021-03-03 11:21:43 -08002553 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002554
2555 @staticmethod
2556 def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
2557 # IFM: NHWC
2558 # Filter: HWCM
2559 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08002560 h = (
2561 ifm.shape[1]
2562 - filter.shape[0]
2563 - (filter.shape[0] - 1) * (dilations[0] - 1)
2564 + padding[0]
2565 + padding[1]
2566 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002567
Kevin Cheng550ccc52021-03-03 11:21:43 -08002568 w = (
2569 ifm.shape[2]
2570 - filter.shape[1]
2571 - (filter.shape[1] - 1) * (dilations[1] - 1)
2572 + padding[2]
2573 + padding[3]
2574 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002575
2576 if h <= 0 or w <= 0:
2577 # Invalid test parameters?
2578 h = 0
2579 w = 0
Kevin Cheng550ccc52021-03-03 11:21:43 -08002580 ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002581
2582 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
2583
Kevin Cheng3a478572021-01-22 17:21:02 -08002584 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002585 out_dtype = DType.INT32
2586 elif ifm.dtype == DType.INT16:
2587 out_dtype = DType.INT48
2588 elif ifm.dtype == DType.FLOAT:
2589 out_dtype = DType.FLOAT
2590 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002591 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002592
Kevin Cheng550ccc52021-03-03 11:21:43 -08002593 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002594
2595 @staticmethod
2596 def pool2dOp(ser, ifm, kernel, stride, pad):
2597 # input: NHWC
2598 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
2599 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
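        # Worked example: ifm H = 8, kernel 2, stride 2, pad (0, 0)
        # gives h = (8 + 0 + 0 + 2 - 2) // 2 = 4.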
2600
2601 if h <= 0 or w <= 0:
2602 # Invalid test parameters?
2603 h = 0
2604 w = 0
Kevin Cheng550ccc52021-03-03 11:21:43 -08002605 ser.setExpectedFailure(True, "Invalid combination of pooling parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002606
2607 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002608 return ser.addOutput(ofm_shape, ifm.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002609
2610 @staticmethod
2611 def fullyConnectedOp(ser, input, filter):
2612 # input: N, IC
2613 # filter: OC, IC
2614 # output: N, OC
2615
2616 output_shape = [input.shape[0], filter.shape[0]]
2617
Kevin Cheng3a478572021-01-22 17:21:02 -08002618 if input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002619 out_dtype = DType.INT32
2620 elif input.dtype == DType.INT16:
2621 out_dtype = DType.INT48
2622 elif input.dtype == DType.FLOAT:
2623 out_dtype = DType.FLOAT
2624 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002625 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002626
Kevin Cheng550ccc52021-03-03 11:21:43 -08002627 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002628
2629 @staticmethod
2630 def matmulOp(ser, a, b):
Kevin Cheng2d60f002021-06-09 14:18:32 -07002631 # a: N, H, C
2632 # b: N, C, W
2633 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07002634
Kevin Cheng2d60f002021-06-09 14:18:32 -07002635 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002636
Kevin Cheng3a478572021-01-22 17:21:02 -08002637 if a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002638 out_dtype = DType.INT32
2639 elif a.dtype == DType.INT16:
2640 out_dtype = DType.INT48
2641 elif a.dtype == DType.FLOAT:
2642 out_dtype = DType.FLOAT
2643 else:
            raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002645
Kevin Cheng550ccc52021-03-03 11:21:43 -08002646 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002647
2648 @staticmethod
2649 def concatOp(ser, a, b, axis):
2650
2651 output_shape = a.shape.copy()
2652 output_shape[axis] = a.shape[axis] + b.shape[axis]
2653
Kevin Cheng550ccc52021-03-03 11:21:43 -08002654 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002655
2656 @staticmethod
2657 def padOp(ser, a, padding):
2658
2659 output_shape = a.shape.copy()
2660
2661 for i in range(len(output_shape)):
2662 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
2663
Kevin Cheng550ccc52021-03-03 11:21:43 -08002664 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002665
2666 @staticmethod
2667 def reshapeOp(ser, a, shape):
2668 output_shape = shape.copy()
2669
2670 totalElements = 1
2671 for i in a.shape:
2672 totalElements *= i
2673
2674 # If there are any -1 elements, figure out what that dimension must be
2675 totalOutputElements = 1
2676 for i in output_shape:
2677 if i != -1:
2678 totalOutputElements *= i
2679
2680 # And fill it in
2681 for i in range(len(output_shape)):
2682 if output_shape[i] == -1:
2683 output_shape[i] = totalElements // totalOutputElements
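        # e.g. reshaping a [2, 3, 4] tensor (24 elements) to [4, -1]
        # infers the -1 dimension as 24 // 4 = 6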
2684
Kevin Cheng550ccc52021-03-03 11:21:43 -08002685 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002686
2687 @staticmethod
2688 def sliceOp(ser, a, begin, size):
2689
2690 output_shape = size.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002691 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002692
2693 @staticmethod
2694 def tileOp(ser, a, multiples):
2695
2696 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002697 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07002698
2699 for i in range(len(output_shape)):
2700 output_shape[i] = a.shape[i] * multiples[i]
2701
Kevin Cheng550ccc52021-03-03 11:21:43 -08002702 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002703
2704 @staticmethod
2705 def transposeOp(ser, a, perms):
2706 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002707 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07002708
2709 for i in range(len(output_shape)):
2710 output_shape[i] = a.shape[perms[i]]
2711
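        # e.g. a.shape == [1, 2, 3] with perms == [2, 0, 1] gives [3, 1, 2]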
Kevin Cheng550ccc52021-03-03 11:21:43 -08002712 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002713
2714 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -08002715 def gatherOp(ser, values, indices):
2716 assert len(values.shape) == 3
2717 assert len(indices.shape) == 2
2718 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07002719
Kevin Cheng77d0f762020-11-24 10:26:32 -08002720 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
2721
Kevin Cheng550ccc52021-03-03 11:21:43 -08002722 return ser.addOutput(output_shape, values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08002723
2724 @staticmethod
2725 def scatterOp(ser, values_in, indices, input):
2726 assert len(values_in.shape) == 3
2727 assert len(indices.shape) == 2
2728 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08002729 assert values_in.shape[0] == indices.shape[0] # N
2730 assert input.shape[1] == indices.shape[1] # W
2731 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08002732
2733 output_shape = values_in.shape
2734
Kevin Cheng550ccc52021-03-03 11:21:43 -08002735 return ser.addOutput(output_shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002736
2737 @staticmethod
2738 def tableOp(ser, input, table):
        # Same shape as the input; the INT16 TABLE op produces an INT32 output.
Kevin Cheng550ccc52021-03-03 11:21:43 -08002740 return ser.addOutput(input.shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07002741
2742 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08002743 def resizeOp(
2744 ser,
2745 input,
2746 mode,
2747 stride,
2748 offset,
2749 shift,
2750 stride_fp,
2751 offset_fp,
2752 output_dims,
2753 input_dtype,
2754 output_dtype,
2755 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002756
2757 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
2758
Kevin Cheng77d0f762020-11-24 10:26:32 -08002759 if input_dtype == DType.FLOAT:
2760 if stride_fp[0] <= 0 or stride_fp[1] <= 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002761 ser.setExpectedFailure(True, "Negative or zero stride")
Kevin Cheng77d0f762020-11-24 10:26:32 -08002762 else:
2763 if stride[0] <= 0 or stride[1] <= 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002764 ser.setExpectedFailure(True, "Negative or zero stride")
Eric Kunzee5e26762020-10-13 16:11:07 -07002765
        if mode == ResizeMode.BILINEAR:
            if input_dtype == DType.INT8:
                if output_dtype != DType.INT32:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.INT16:
                if output_dtype != DType.INT48:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.FLOAT:
                if output_dtype != DType.FLOAT:
                    ser.setExpectedFailure(True, "Invalid output data type")
            else:
                ser.setExpectedFailure(True, "Invalid input data type")

        elif mode == ResizeMode.NEAREST:
            if input_dtype == DType.INT8:
                if output_dtype != DType.INT8:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.INT16:
                if output_dtype != DType.INT16:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.FLOAT:
                if output_dtype != DType.FLOAT:
                    ser.setExpectedFailure(True, "Invalid output data type")
            else:
                ser.setExpectedFailure(True, "Invalid input data type")

        else:
            ser.setExpectedFailure(True, "Invalid resize mode")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002794
Kevin Cheng550ccc52021-03-03 11:21:43 -08002795 return ser.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002796
2797 @staticmethod
2798 def typeConversionOp(ser, val, out_dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002799 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002800
2801 @staticmethod
2802 def transposeConv2DOp(ser, ifm, output_shape):
Kevin Cheng3a478572021-01-22 17:21:02 -08002803 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002804 out_dtype = DType.INT32
2805 elif ifm.dtype == DType.INT16:
2806 out_dtype = DType.INT48
2807 elif ifm.dtype == DType.FLOAT:
2808 out_dtype = DType.FLOAT
2809 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002810 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002811
2812 if output_shape[1] <= 0 or output_shape[2] <= 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002813 ser.setExpectedFailure(True, "Negative output shape")
Eric Kunzee5e26762020-10-13 16:11:07 -07002814
Kevin Cheng550ccc52021-03-03 11:21:43 -08002815 return ser.addOutput(output_shape, out_dtype)