#!/usr/bin/env python3

# Copyright (c) 2020-2021, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import argparse
import sys
import re
import os
import subprocess
import shlex
import json
import glob
import math
import queue
import threading
import traceback

from enum import IntEnum, Enum, unique

# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(
    os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
)
import tosa_serializer as ts
from tosa_serializer import *
import tosa

# Convenience variables to the flatc-generated types that should be enums, but aren't
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()


class TosaQuantGen:
    """QuantizedInfo random generator helper functions. Specify with 'qgen': in the operator definition"""

    def __init__(self):
        pass

    @staticmethod
    def needsQinfo(op, dtype):
        if dtype == DType.INT8 or dtype == DType.INT16:
            return True
        return False

    @staticmethod
    def qgUnary(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        if TosaQuantGen.needsQinfo(op, dtype):
            qinfo.UnaryQuantInfo(testGen.randInt(), testGen.randInt())
        else:
            qinfo.UnaryQuantInfo(0, 0)
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        if TosaQuantGen.needsQinfo(op, dtype):
            qinfo.ConvQuantInfo(testGen.randInt(), testGen.randInt())
        else:
            qinfo.ConvQuantInfo(0, 0)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        if TosaQuantGen.needsQinfo(op, dtype):
            qinfo.MatMulQuantInfo(testGen.randInt(), testGen.randInt())
        else:
            qinfo.MatMulQuantInfo(0, 0)
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        if TosaQuantGen.needsQinfo(op, dtype):
            qinfo.PadQuantInfo(testGen.randInt())
        else:
            qinfo.PadQuantInfo(0)
        return qinfo

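    # How these are wired up (an editor's sketch inferred from the class
    # docstring above, not a verbatim operator-table entry): an operator
    # definition is expected to carry something like
    # 'qgen': TosaQuantGen.qgUnary, which the test machinery calls to produce
    # the quantization info for the operator under test.
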
    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(scaleFp, scaleBits, m, multiplier, shift))

        assert multiplier <= (1 << scaleBits)
        assert shift >= 0 and shift <= 63

        return multiplier, shift
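
    # Worked example (editor's sketch, following the arithmetic above):
    #   computeMultiplierAndShift(0.375, scale32=True)
    #   math.frexp(0.375) returns (0.75, -1), so m = 0.75 and shift = -1
    #   multiplier = round(0.75 * (1 << 31)) = 1610612736
    #   shift = -(-1) + 31 = 32, and 1610612736 * 2.0 ** -32 == 0.375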


class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator. The actual random data is generated separately for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank):
        pl, const = opName["operands"]

        assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrain the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # Constrain the batch size?
        if testGen.args.max_batch_size:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If this is the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrain the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrain the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrain the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        filter_oc = testGen.makeShape(1)[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        b_oc = testGen.makeShape(1)[0]
        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])

        return [a_shape, b_shape]
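
    # Shape note (editor's): a is [N, H, C] and b is [N, C, OC], so the batched
    # matrix multiply produces a result of shape [N, H, OC].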


class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for operators that take
    attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
    tuples where the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function."""

    def __init__(self):
        pass

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype):
        """Build the axis argument for operators that take a single axis"""
        axes = []

        shape = shapeList[0]

        for a in range(0, len(shape)):
            axes.append(("axis_{}".format(a), [a]))
        return axes

    @staticmethod
    def agConv2D(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        maxStride = testGen.args.max_conv_stride
        maxPadding = testGen.args.max_conv_padding + 1
        maxDilation = testGen.args.max_conv_dilation

        # Strides, padding, dilations
        for stride in range(0, maxStride ** 2):
            for padding in range(0, (maxPadding) ** 4):
                for dilation in range(0, maxDilation ** 2):

                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    # Decode the combination index as four base-maxPadding digits
                    p = [
                        (padding // (maxPadding ** 3)) % maxPadding,
                        (padding // (maxPadding ** 2)) % maxPadding,
                        (padding // (maxPadding ** 1)) % maxPadding,
                        padding % maxPadding,
                    ]
                    d = [dilation // maxDilation + 1, dilation % maxDilation + 1]

                    # 4 padding parameters for regular conv2d
                    arg_list.append(
                        (
                            "st{}{}_pad{}{}{}{}_dilat{}{}".format(
                                s[0], s[1], p[0], p[1], p[2], p[3], d[0], d[1]
                            ),
                            [s, p, d],
                        )
                    )
        return arg_list
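
    # e.g. stride [1, 1], padding [0, 0, 0, 0], dilation [1, 1] produces the
    # descriptor "st11_pad0000_dilat11", which is appended to the test name.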

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        maxStride = testGen.args.max_conv_stride
        maxPadding = testGen.args.max_conv_padding + 1
        maxDilation = testGen.args.max_conv_dilation

        # Strides, padding, dilations
        for stride in range(0, maxStride ** 2):
            for out_padding in range(0, (maxPadding) ** 2):
                for dilation in range(0, maxDilation ** 2):

                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    p = [
                        (out_padding // (maxPadding * 1)) % maxPadding,
                        out_padding % maxPadding,
                    ]
                    d = [dilation // maxDilation + 1, dilation % maxDilation + 1]

                    oh = (
                        ifm_shape[1]
                        - filter_shape[1]
                        - (filter_shape[1] - 1) * (d[0] - 1)
                        + 2 * p[0]
                    ) // s[0] + 1

                    ow = (
                        ifm_shape[2]
                        - filter_shape[2]
                        - (filter_shape[2] - 1) * (d[1] - 1)
                        + 2 * p[1]
                    ) // s[1] + 1

                    # Output shape
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]

                    arg_list.append(
                        (
                            "st{}{}_outpad{}{}_dilat{}{}_os{}x{}x{}x{}".format(
                                s[0],
                                s[1],
                                p[0],
                                p[1],
                                d[0],
                                d[1],
                                os[0],
                                os[1],
                                os[2],
                                os[3],
                            ),
                            [s, p, d, os],
                        )
                    )

        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype):
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of 0/1 padding on each side of each dimension
        # This process might need some revision for >1 padding, but use rank**2 as a bitmask
        # for now
        for v in range(rank ** 2):

            # Create a flat padding array
            paddings = np.zeros((rank * 2), dtype=np.int32)

            # Fill in the 1's
            for r in range(rank * 2):
                if (v >> r) & 1:
                    paddings[r] = 1

            # Reshape back to a 2D array
            paddings = paddings.reshape((rank, 2))

            arg_list.append(("pad{0:b}".format(v), [paddings]))

        return arg_list
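
    # e.g. for rank 2, v = 0b11 gives the flat array [1, 1, 0, 0], which
    # reshapes to [[1, 1], [0, 0]]: pad both sides of dim 0 and neither side
    # of dim 1 (editor's illustration).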

    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype):
        arg_list = []

        shape = shapeList[0]
        assert len(shape) == 4

        maxStride = testGen.args.max_pooling_stride
        maxKernel = testGen.args.max_pooling_kernel
        maxPadding = testGen.args.max_pooling_padding + 1

        for kernel in range(0, maxKernel ** 2):
            for stride in range(0, maxStride ** 2):
                for padding in range(0, maxPadding ** 4):
                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    k = [(kernel // maxKernel) + 2, (kernel % maxKernel) + 2]
                    # Decode the combination index as four base-maxPadding digits
                    p = [
                        (padding // (maxPadding ** 3)) % maxPadding,
                        (padding // (maxPadding ** 2)) % maxPadding,
                        (padding // (maxPadding ** 1)) % maxPadding,
                        padding % maxPadding,
                    ]

                    arg_list.append(
                        (
                            "st{}{}_kern{}{}_pad{}{}{}{}".format(
                                s[0], s[1], k[0], k[1], p[0], p[1], p[2], p[3]
                            ),
                            [k, s, p],
                        )
                    )
        return arg_list
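
    # Note (editor's): the "+ 2" in the kernel decode means the smallest
    # pooling window exercised here is 2x2.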

    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype):
        arg_list = []

        # Enumerate the output types here
        if inDtype == DType.INT8:
            dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT16:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT32:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FLOAT:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))

        return arg_list

    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype):
        arg_list = []

        # Enumerate the output types here
        for dtype in [DType.INT8, DType.INT16, DType.INT32]:
            for scale32 in [False, True]:
                for double_round in [False, True]:
                    for per_channel in [False, True]:

                        if inDtype == DType.INT48 and scale32:
                            # Illegal condition. Must be scale32=False
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[dtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [dtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

    @staticmethod
    def agMul(testGen, opName, shapeList, dtype):
        arg_list = []

        if dtype is DType.INT32:
            for p in range(testGen.args.num_rand_permutations):

                shift = testGen.randInt(0, 32)

                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("shift0", [0]))

        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype):
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    # Helper function for reshape. Gets some factors of a larger number.
    @staticmethod
    def getFactors(val, start=1):
        factors = []

        for i in range(start, int(np.sqrt(val))):
            if (val % i) == 0:
                factors.append(i)

        return factors
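
    # Only factors strictly below sqrt(val) are returned; e.g. (editor's
    # illustration) getFactors(24) -> [1, 2, 3]. agReshape below accounts for
    # the rest by appending remainingElements as the final dimension.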

    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype):
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 6)
            newShape = []
            if len(factors) < newRank:
                continue

            remainingElements = totalElements
            shuffledFactors = testGen.rng.permutation(factors)
            for i in range(newRank):
                # pick rank-1 factors
                newShape.append(shuffledFactors[0])
                remainingElements = remainingElements // shuffledFactors[0]
                shuffledFactors = testGen.rng.permutation(
                    TosaArgGen.getFactors(remainingElements)
                )
            newShape.append(remainingElements)

            # Toss in a -1 sometimes
            minusOne = testGen.randInt(0, newRank * 4)
            if minusOne < newRank:
                newShape[minusOne] = -1

            arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list

    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]

        perms = range(len(ifm_shape))
        for p in range(testGen.args.num_rand_permutations):
            perms = np.int32(testGen.rng.permutation(perms)).tolist()

            # Avoid duplicates
            found = False
            for name, other_perm in arg_list:
                if other_perm[0] == perms:
                    found = True
                    break

            if not found:
                arg_list.append(("perm{}".format(p), [perms]))

        return arg_list

    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            begin = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    begin.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    begin.append(0)
                    size.append(1)

            if valid:
                arg_list.append(("perm{}".format(p), [begin, size]))
        return arg_list

    @staticmethod
    def agTile(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random, but small multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                multiples.append(testGen.randInt(1, 4))

            arg_list.append(("perm{}".format(p), [multiples]))

        return arg_list

    @staticmethod
    def agResize(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]

        for m in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations. Pick legal output types
            if m == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif m == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):

                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    output_dims = [testGen.randInt(1), testGen.randInt(1)]
                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]
                        arg_list.append(
                            (
                                "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                    m,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride_fp[0],
                                    stride_fp[1],
                                    offset_fp[0],
                                    offset_fp[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )
                    else:
                        shift = 11
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))

                        while (
                            stride_y >= 32768
                            or stride_x >= 32768
                            or offset_y >= 32768
                            or offset_x >= 32768
                            or offset_y < -32768
                            or offset_x < -32768
                        ):
                            shift = shift - 1
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                        arg_list.append(
                            (
                                "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                    m,
                                    shift,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride[0],
                                    stride[1],
                                    offset[0],
                                    offset[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )

        return arg_list
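
    # Fixed-point example (editor's arithmetic): a 4-row input resized to 8
    # rows gives fp_stride_y = 4 / 8 = 0.5; with shift = 11, unit = 2048 and
    # stride_y = round(0.5 * 2048) = 1024, comfortably inside the +/-32768
    # limits, so no shift reduction is needed.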

    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype):
        # While loop: 0 iterations, 1, more than 1
        arg_list = []

        for iter in [0, 1, 4]:
            arg_list.append(("iter{}".format(iter), [iter]))

        return arg_list


class TosaTestGen:
    def __init__(self, args):
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        self.ser = None
        self.rng = np.random.default_rng(self.random_seed)
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None

    def createSerializer(self, opName, testPath):
        self.testPath = os.path.join(opName, testPath)

        fullPath = os.path.join(self.basePath, self.testPath)
        os.makedirs(fullPath, exist_ok=True)
        self.ser = ts.TosaSerializer(fullPath)

    def getSerializer(self):
        return self.ser

    def serialize(self, testName):
        with open(
            os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
        ) as fd:
            fd.write(self.ser.serialize())

        with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
            fd.write(self.ser.writeJson("{}.tosa".format(testName)))
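
    # Typical flow (editor's sketch; assumes an argparse namespace shaped like
    # the one this tool parses elsewhere, and a hypothetical test name):
    #   gen = TosaTestGen(args)
    #   gen.createSerializer("add", "add_13x21x3_i32")
    #   ... build the operator under test with the build_* helpers below ...
    #   gen.serialize("add_13x21x3_i32")  # writes <name>.tosa plus desc.json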

    def getRandTensor(self, shape, dtype):
        RAND_SHIFT_FACTOR = 0.5
        RAND_SCALE_FACTOR = 4.0

        if dtype == DType.BOOL:
            return np.bool_(self.rng.choice(a=[False, True], size=shape))
        elif dtype == DType.INT4:
            return np.int32(self.rng.integers(low=-7, high=8, size=shape))
        elif dtype == DType.INT8:
            return np.int32(self.rng.integers(low=-127, high=128, size=shape))
        elif dtype == DType.INT16:
            return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
        elif dtype == DType.INT32:
            return np.int32(
                self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
            )
        elif dtype == DType.INT48:
            return np.int64(
                self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
            )
        elif dtype == DType.FLOAT:
            # Uniform in [-2.0, 2.0): shift to be zero-centered, then scale
            return np.float32(
                (self.rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR
            )
        else:
            raise Exception("Unrecognized Dtype: {}".format(dtype))

    def buildPlaceholderTensors(self, shape_list, dtype_list):
        placeholders = []

        assert len(shape_list) == len(dtype_list)

        for idx, shape in enumerate(shape_list):
            arr = self.getRandTensor(shape, dtype_list[idx])
            placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))

        return placeholders

    def buildConstTensors(self, shape_list, dtype_list):
        consts = []

        assert len(shape_list) == len(dtype_list)

        for idx, shape in enumerate(shape_list):
            arr = self.getRandTensor(shape, dtype_list[idx])
            consts.append(self.ser.addConst(shape, dtype_list[idx], arr))

        return consts

    def makeShape(self, rank):
        if self.targetted_shape:
            return np.int32(self.targetted_shape)
        return np.int32(
            self.rng.integers(
                low=self.args.tensor_shape_range[0],
                high=self.args.tensor_shape_range[1],
                size=rank,
            )
        )

    def setTargetShape(self, shape):
        self.targetted_shape = shape

    def randInt(self, low=0, high=256):
        return np.int32(self.rng.integers(low=low, high=high, size=1))[0]

    def getRandNumberDType(self, dtype):
        if dtype == DType.FLOAT:
            return self.rng.random()
        elif dtype == DType.BOOL:
            return self.rng.choice([False, True])
        elif dtype == DType.INT4:
            low, high = (-7, 8)
        elif dtype == DType.INT8:
            low, high = (-127, 128)
        elif dtype == DType.INT16:
            low, high = (-32768, 32768)
        elif dtype == DType.INT32:
            low, high = (-(1 << 31), (1 << 31))
        elif dtype == DType.INT48:
            low, high = (-(1 << 47), (1 << 47))
            # Special size
            return np.int64(self.rng.integers(low, high, size=1))[0]
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

        return np.int32(self.rng.integers(low, high, size=1))[0]

    def shapeStr(self, shape):

        sStr = []
        # Convert to strings
        for i in shape:
            sStr.append(str(i))

        return "x".join(sStr)

    def typeStr(self, t):
        if isinstance(t, list):
            assert len(t) >= 2
            return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
        else:
            if t == DType.BOOL:
                return "b"
            elif t == DType.INT4:
                return "i4"
            elif t == DType.INT8:
                return "i8"
            elif t == DType.UINT8:
                return "u8"
            elif t == DType.INT16:
                return "i16"
            elif t == DType.INT32:
                return "i32"
            elif t == DType.INT48:
                return "i48"
            elif t == DType.FLOAT:
                return "float"
            else:
                raise Exception("Unknown dtype, cannot convert to string: {}".format(t))

    def typeWidth(self, t):
        """Get the datatype width for integer types"""
        if t == DType.INT4:
            return 4
        elif t == DType.INT8:
            return 8
        elif t == DType.UINT8:
            return 8
        elif t == DType.INT16:
            return 16
        elif t == DType.INT32:
            return 32
        elif t == DType.INT48:
            return 48
        else:
            raise Exception("Unknown dtype, cannot determine width: {}".format(t))

    # Argument generators
    # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
    # Where the string descriptor is used to generate the test name and
    # The build_fcn_arg_list is expanded and passed to the operator test
    # build function

    def build_unary(self, op, a, qinfo=None):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        self.ser.addOperator(op, [a.name], [result_tens.name], None, qinfo)
        return result_tens

    def build_binary_broadcast(self, op, a, b):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
        return result_tens

    def build_binary_nonbroadcast(self, op, a, b):
        result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
        return result_tens

    def build_arithmetic_right_shift(self, op, a, b, round):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)

        attr = ts.TosaSerializerAttribute()
        attr.ArithmeticRightShiftAttribute(round)

        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
        return result_tens

    def build_mul(self, op, a, b, shift):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)

        # Special for multiply:
        # Force the result to INT32 for INT types
        if a.dtype != DType.FLOAT:
            result_tens.setDtype(DType.INT32)

        attr = ts.TosaSerializerAttribute()
        attr.MulAttribute(shift)

        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
        return result_tens

    def build_table(self, op, a):
        # Constant size, random values
        table_arr = self.getRandTensor([513], DType.INT16)
        table_tens = self.ser.addConst(table_arr.shape, DType.INT16, table_arr)

        result_tens = OutputShaper.tableOp(self.ser, a, table_tens)
        self.ser.addOperator(op, [a.name, table_tens.name], [result_tens.name], None)

        return result_tens

    def build_select(self, op, cond, a, b):

        # Replace the cond tensor with a boolean tensor since it probably
        # has the wrong dtype
        t = self.buildPlaceholderTensors([cond.shape], [DType.BOOL])
        cond = t[0]

        result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
        self.ser.addOperator(op, [cond.name, a.name, b.name], [result_tens.name])

        return result_tens

    def build_comparison(self, op, a, b):
        result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
        return result_tens

    def build_argmax(self, op, a, axis):
        result_tens = OutputShaper.argmaxOp(self.ser, a, axis)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_pool2d(self, op, input, kernel, stride, pad, qinfo=None):
        result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)

        attr = ts.TosaSerializerAttribute()
        attr.Pool2dAttribute(kernel, stride, pad)

        self.ser.addOperator(op, [input.name], [result_tens.name], attr, qinfo)
        return result_tens

    def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
        assert len(padding) == 4
        result_tens = OutputShaper.conv2dOp(
            self.ser, ifm, filter, strides, padding, dilations
        )

        attr = ts.TosaSerializerAttribute()
        attr.Conv2dAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_transpose_conv2d(
        self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
    ):
        assert len(outpad) == 2
        result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)

        attr = ts.TosaSerializerAttribute()
        attr.TransposeConv2DAttribute(outpad, stride, dilation, output_shape)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_depthwise_conv2d(
        self, op, ifm, filter, bias, strides, padding, dilations, qinfo
    ):
        result_tens = OutputShaper.depthwiseConv2dOp(
            self.ser, ifm, filter, strides, padding, dilations
        )

        attr = ts.TosaSerializerAttribute()
        attr.Conv2dAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_fully_connected(self, op, ifm, filter, bias, qinfo):
        result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
        )
        return result_tens

    def build_matmul(self, op, a, b, qinfo):
        result_tens = OutputShaper.matmulOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], None, qinfo)
        return result_tens

    def build_reduce(self, op, a, axis):
        result_tens = OutputShaper.reduceOp(self.ser, a, axis)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_clamp(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        attr = ts.TosaSerializerAttribute()

        # Get two random ints
        v = [self.randInt(), self.randInt()]

        if a.dtype == DType.FLOAT:
            attr.ClampAttribute(0, 0, min(v), max(v))
        else:
            attr.ClampAttribute(min(v), max(v), 0, 0)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_leaky_relu(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        attr = ts.TosaSerializerAttribute()

        attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    # Needs an additional type/input
    def build_prelu(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        self.ser.addOperator(op, [a.name], [result_tens.name])
        return result_tens

    def build_relun(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        attr = ts.TosaSerializerAttribute()

        if a.dtype == DType.FLOAT:
            attr.ReluNAttribute(0, self.getRandNumberDType(a.dtype))
        else:
            attr.ReluNAttribute(self.getRandNumberDType(a.dtype), 0)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_sigmoid(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        self.ser.addOperator(op, [a.name], [result_tens.name])
        return result_tens

    def build_tanh(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        self.ser.addOperator(op, [a.name], [result_tens.name])
        return result_tens

    def build_concat(self, op, a, b, axis):
        result_tens = OutputShaper.concatOp(self.ser, a, b, axis)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
        return result_tens

    def build_pad(self, op, a, padding, qinfo):
        result_tens = OutputShaper.padOp(self.ser, a, padding)

        # Need to turn the padding array into a TOSA tensor here.
        # This is one of the few tensor operands that does not get
        # randomly generated
        padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)

        self.ser.addOperator(
            op, [a.name, padding_tens.name], [result_tens.name], None, qinfo
        )
        return result_tens

    def build_reshape(self, op, a, newShape):
        result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)

        attr = ts.TosaSerializerAttribute()
        attr.ReshapeAttribute(newShape)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_reverse(self, op, a, axis):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_transpose(self, op, a, perms):
        result_tens = OutputShaper.transposeOp(self.ser, a, perms)

        perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))

        self.ser.addOperator(op, [a.name, perms_tens.name], [result_tens.name])
        return result_tens

    def build_slice(self, op, a, begin, size):
        result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)

        attr = ts.TosaSerializerAttribute()
        attr.SliceAttribute(begin, size)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_tile(self, op, a, multiples):
        result_tens = OutputShaper.tileOp(self.ser, a, multiples)

        attr = ts.TosaSerializerAttribute()
        attr.TileAttribute(multiples)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_gather(self, op, values):

        # Create a new indicies tensor
        # here with data that doesn't exceed the dimensions of the values tensor

        K = values.shape[1]  # K
        W = self.randInt(
            self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
        )  # W
        indicies_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values.shape[0], W])
        )  # (N, W)
        indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

        result_tens = OutputShaper.gatherOp(self.ser, values, indicies)

        self.ser.addOperator(op, [values.name, indicies.name], [result_tens.name])

        return result_tens

    def build_scatter(self, op, values_in, input):

        # Create a new indicies tensor
        # here with data that doesn't exceed the dimensions of the values_in tensor

        K = values_in.shape[1]  # K
        W = input.shape[1]  # W
        indicies_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
        )  # (N, W)
        indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

        result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)

        self.ser.addOperator(
            op, [values_in.name, indicies.name, input.name], [result_tens.name]
        )

        return result_tens

    def build_resize(
        self,
        op,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
    ):
        result_tens = OutputShaper.resizeOp(
            self.ser,
            input,
            mode,
            stride,
            offset,
            shift,
            stride_fp,
            offset_fp,
            output_dims,
            input_dtype,
            output_dtype,
        )

        attr = ts.TosaSerializerAttribute()

        attr.ResizeAttribute(
            output_dims, stride, offset, shift, stride_fp, offset_fp, mode
        )

        self.ser.addOperator(op, [input.name], [result_tens.name], attr)
        return result_tens

    def build_identityn(self, op, val, val2):

        result_tens = OutputShaper.unaryOp(self.ser, val)
        result_tens2 = OutputShaper.unaryOp(self.ser, val2)
        self.ser.addOperator(
            op, [val.name, val2.name], [result_tens.name, result_tens2.name]
        )
        return result_tens

    def build_placeholder(self, op, val):
        # Add an identity op to avoid warning in the reference model
        return self.build_unary(Op.IDENTITY, val)

    # Type Conversion
    def build_cast(self, op, val, out_dtype):
        result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
        self.ser.addOperator(op, [val.name], [result_tens.name])
        return result_tens

    def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
        result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)

        if per_channel:
            nc = val.shape[-1]
        else:
            nc = 1

        in_type_width = self.typeWidth(val.dtype)
        out_type_width = self.typeWidth(out_dtype)

        if val.dtype == DType.INT8:
            # randInt's upper bound is exclusive, so this covers -128..127
            input_zp = self.randInt(-128, 128)
            in_type_width = in_type_width + 1
        else:
            input_zp = 0

        if out_dtype == DType.INT8:
            output_zp = self.randInt(-128, 128)
            out_type_width = out_type_width + 1
        else:
            output_zp = 0

        # Calculate scale based on:
        # scale = a * (2^output_width) / (2^input_width)

        a = np.float32(self.rng.random(size=[nc]))
        scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

        if scale32:
            # Cap the scaling at 2^31 - 1 for scale32
            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
        else:
            # Cap the scaling at 2^15 - 1 for scale16
            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

        # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))

        multiplier_arr = np.int32(np.zeros(shape=[nc]))
        shift_arr = np.int32(np.zeros(shape=[nc]))

        for i in range(nc):
            multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
                scale_arr[i], scale32
            )
            if shift_arr[i] < 2 or shift_arr[i] > 62:
                self.ser.setExpectedFailure(True, "OpRescale: invalid shift value")

        # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))

        attr = ts.TosaSerializerAttribute()
        attr.RescaleAttribute(
            input_zp,
            output_zp,
            multiplier_arr,
            shift_arr,
            scale32,
            double_round,
            per_channel,
        )

        self.ser.addOperator(op, [val.name], [result_tens.name], attr)
        return result_tens
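
    # Scale example (editor's arithmetic): for an int8 -> int8 rescale both
    # widths become 9 after the zero-point adjustment, so scale_arr is just a
    # value in (0, 1); computeMultiplierAndShift(0.5, True) returns
    # (1073741824, 31), i.e. 0.5 represented as multiplier * 2**-shift.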

    def build_cond_if_const(self, op, then_tens, else_tens, cond):
        # For cond_if with constants, we're supplied with then/else tensors that we ignore
        # (except for the generated shape) and the condition. Build Then/Else blocks
        # and fill them with const nodes for the body.

        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        # Make then/else tensors
        out_shape = then_tens.shape
        then_arr = np.int32(self.rng.integers(0, 255, size=out_shape))
        else_arr = np.int32(self.rng.integers(0, 255, size=out_shape))

        # And the result tensor based on any of the outputs
        result_tens = self.ser.addOutput(out_shape, DType.INT32)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(op, [cond_tens.name], [result_tens.name], attr)

        self.ser.startBasicBlock(then_block)
        # Build the actual then/else tensors inside their blocks
        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
        self.ser.addOutputTensor(then_tens)

        self.ser.startBasicBlock(else_block)
        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
        self.ser.addOutputTensor(else_tens)

        return result_tens

    def build_cond_if_binary(self, op, a, b, cond):
        # For cond_if with a binary op in the then/else blocks, take a and b and
        # alternately add or subtract them based on the condition

        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        result_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.currBasicBlock.addOutput(result_tens.name)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(
            op, [cond_tens.name, a.name, b.name], [result_tens.name], attr
        )

        self.ser.startBasicBlock(then_block)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        then_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])

        self.ser.startBasicBlock(else_block)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        else_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])

        return result_tens
1534
    def build_while_loop(self, op, a, iter_val):
        iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

        cond_block = "COND_BLOCK"
        body_block = "BODY_BLOCK"

        attr = ts.TosaSerializerAttribute()
        attr.WhileLoopAttribute(cond_block, body_block)

        # Accumulator tensor
        # acc = self.ser.addOutput(a.shape, a.dtype)
        acc_init_val = np.int32(np.zeros(a.shape))
        acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

        # Intermediate/output tensors for everything going through the loop
        iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        a_out = self.ser.addIntermediate(a.shape, a.dtype)
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

        # While_loop operator
        self.ser.addOperator(
            op,
            [iter_tens.name, a.name, acc.name],
            [iter_out.name, a_out.name, acc_out.name],
            attr,
        )

        # COND block (input: iter, output: cond_tens)
        self.ser.startBasicBlock(cond_block)
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
        cond_tens = self.ser.addOutput([], DType.BOOL)
        self.ser.addOperator(
            Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
        )

        # BODY block (input: a, acc, iter, output: a, acc, iter)
        # Note that local intermediate tensors need to be declared here for the outputs
        self.ser.startBasicBlock(body_block)
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
        self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
        self.ser.addOperator(Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name])
        self.ser.addOutputTensor(iter_body_out)
        self.ser.addOutputTensor(a)
        self.ser.addOutputTensor(acc_body_out)

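        # Net effect of the generated graph (a sketch of the semantics): each
        # body iteration adds `a` into the accumulator and decrements the
        # counter, so for a non-negative iter_val the final acc_out should
        # equal a * iter_val elementwise.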
        return acc_out

    def genOpTestList(
        self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None
    ):

        try:
            op = self.TOSA_OP_LIST[opName]
        except KeyError:
            raise Exception("Cannot find op with name {}".format(opName))

        # Initialize a new random number generator
        self.rng = np.random.default_rng(self.random_seed)

        build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]

        # Generate the lists of arguments
        rmin, rmax = op["rank"]

        # Test list consists of tuples of:
        # (opName, testNameStr, dtype, shapeList, argumentsList)
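        # e.g. a single (hypothetical) entry might look like:
        #   ("add", "add_4x8_i32", DType.INT32, [[4, 8], [4, 8]], [])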
        testList = []

        if not shapeFilter:
            shapeFilter = [None]

        for r in range(rmin, rmax + 1):

            # Filter out the rank?
            if rankFilter is not None and r not in rankFilter:
                continue

            for t in op["types"]:

                # Filter tests based on dtype?
                if dtypeFilter is not None:
                    if t not in dtypeFilter:
                        continue

                # Create the placeholder and const tensors
                for shape in shapeFilter:
                    # A None shape chooses a random shape of a given rank

                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue

                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of tuples of the (str, []) string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if argStr:
                            testStr = "{}_{}_{}_{}".format(
                                opName, shapeStr, typeStr, argStr
                            )
                        else:
                            testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)

                        testList.append((opName, testStr, t, shapeList, args))

        return testList

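    # A typical (hypothetical) driver pairs the two entry points, e.g.:
    #   for opName, testStr, dtype, shapes, args in gen.genOpTestList("add"):
    #       gen.serializeTest(opName, testStr, dtype, shapes, args)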
    def serializeTest(self, opName, testStr, dtype_or_dtypeList, shapeList, testArgs):
        try:
            op = self.TOSA_OP_LIST[opName]
        except KeyError:
            raise Exception("Cannot find op with name {}".format(opName))

        # Create a serializer
        self.createSerializer(opName, testStr)

        build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount

        if isinstance(dtype_or_dtypeList, list):
            dtypeList = dtype_or_dtypeList
        else:
            dtypeList = [dtype_or_dtypeList] * (num_operands)

        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

        try:
            qgen = op["qgen"]
        except KeyError:
            qgen = None

        # Build the random tensor operands and the test
        tens = []

        # If test is ArithmeticRightShift, force the value of operand[1] to be
        # within [0, num_bits)
        if op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if idx == 1:
                    if dtypeList[idx] == DType.INT8:
                        arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
                    elif dtypeList[idx] == DType.INT16:
                        arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
                    elif dtypeList[idx] == DType.INT32:
                        arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
                    else:
                        raise Exception("OpArithmeticRightShift: invalid input dtype")
                else:
                    arr = self.getRandTensor(shape, dtypeList[idx])
                placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))

            tens.extend(placeholders)
        elif op["op"] == Op.DIV:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.Div must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.DIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
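            # (Case 2 is excluded because -(2**31) // -1 == 2**31, which does
            # not fit in int32.)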
            while True:
                dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
                    continue

                break

            placeholders.append(
                self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            placeholders.append(
                self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            tens.extend(placeholders)
        elif op["op"] == Op.MUL:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            if dtypeList[0] == DType.FLOAT:
                tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure the multiply result stays within int32 range
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                else:
                    raise Exception("OpMul: invalid input dtype")

                low = -(2 ** (num_bits - 1))
                high = (2 ** (num_bits - 1)) - 1

                a_arr = np.int32(
                    self.rng.integers(low=low, high=high, size=shapeList[0])
                )
                b_arr = np.int32(
                    self.rng.integers(low=low, high=high, size=shapeList[1])
                )

                # Halve both operands until every (rounded, shifted) product
                # fits in int32
                while True:
                    a_arr_64 = a_arr.astype(np.int64)
                    b_arr_64 = b_arr.astype(np.int64)

                    if shift > 0:
                        rounding = 1 << (shift - 1)
                        result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                    else:
                        result_arr = a_arr_64 * b_arr_64

                    if (result_arr > -(2 ** 31)).all() and (
                        result_arr <= ((2 ** 31) - 1)
                    ).all():
                        break

                    a_arr = a_arr // 2
                    b_arr = b_arr // 2

                placeholders.append(
                    self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)
        else:
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
            )
            tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        if qgen is not None:
            qinfo = qgen(self, op, dtypeList[0])
        else:
            qinfo = None

        try:
            if qinfo is not None:
                resultName = build_fcn(self, op["op"], *tens, *testArgs, qinfo)
            else:
                resultName = build_fcn(self, op["op"], *tens, *testArgs)
        except TypeError as e:
            print(
                "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
                    build_fcn, tens, testArgs
                )
            )
            raise e

        # Save the serialized test
        self.serialize("test")

1835 def createDynamicOpLists(self):
1836
1837 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng550ccc52021-03-03 11:21:43 -08001838 KERNELS = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07001839
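        # Each *_TEMPLATE op expands once per kernel size; e.g. "conv2d_TEMPLATE"
        # becomes "conv2d_1x1", "conv2d_2x2", "conv2d_3x3", "conv2d_5x5",
        # "conv2d_3x1" and "conv2d_1x3". The templates themselves are deleted
        # at the end of this method.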
        for k in KERNELS:
            testName = "conv2d_{}x{}".format(k[0], k[1])
            self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
            self.TOSA_OP_LIST[testName]["filter"] = k
            self.TOSA_OP_LIST[testName]["template"] = False

            testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
            self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
                "depthwise_conv2d_TEMPLATE"
            ].copy()
            self.TOSA_OP_LIST[testName]["filter"] = k
            self.TOSA_OP_LIST[testName]["template"] = False

            testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
            self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
                "transpose_conv2d_TEMPLATE"
            ].copy()
            self.TOSA_OP_LIST[testName]["filter"] = k
            self.TOSA_OP_LIST[testName]["template"] = False

        # Delete any templates after having created the dynamic ops.
        # This is a two-pass operation because it's bad practice to delete
        # keys from dictionaries while iterating.
        keyList = []
        for k in self.TOSA_OP_LIST:
            try:
                if self.TOSA_OP_LIST[k]["template"]:
                    keyList.append(k)
                    continue
            except KeyError:
                pass

        for k in keyList:
            del self.TOSA_OP_LIST[k]

    def initOpListDefaults(self):
        """Fill in default fields for ops if they aren't already specified.
        Look for missing required fields (data structure linting)."""
        for op in self.TOSA_OP_LIST:

            # Required fields
            try:
                pl, c = self.TOSA_OP_LIST[op]["operands"]
            except (KeyError, ValueError, TypeError):
                raise Exception(
                    "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
                )

            try:
                fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
            except (KeyError, ValueError, TypeError):
                raise Exception(
                    "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                        op
                    )
                )

            try:
                types = self.TOSA_OP_LIST[op]["types"]
            except KeyError:
                raise Exception(
                    "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
                )

            try:
                opcode = self.TOSA_OP_LIST[op]["op"]
            except KeyError:
                raise Exception(
                    "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
                )

            # Put in the default rank range, if missing
            try:
                rank = self.TOSA_OP_LIST[op]["rank"]
            except KeyError:
                self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE

    # Tensor operator list
    # 'op': op name
    # 'operands': tuple of (placeholder, const) operands
    # 'rank': optional, restricts rank to tuple inclusive of (min, max),
    #   if not specified, defaults to (1, 4)
    # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
    # 'types': array of datatypes to be tested
    TYPE_FP = [DType.FLOAT]

    TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
    TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

    TYPE_BOOL = [DType.BOOL]
    TYPE_FI32 = [DType.FLOAT, DType.INT32]
    TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
    TYPE_FI16 = [DType.FLOAT, DType.INT16]

    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

    TYPE_CONV2D = [
        [DType.INT8, DType.INT8, DType.INT32],
        [DType.INT16, DType.INT8, DType.INT48],
        DType.FLOAT,
    ]

    DEFAULT_RANK_RANGE = (1, 4)

    TOSA_OP_LIST = {
        # Tensor operators
        "argmax": {
            "op": Op.ARGMAX,
            "operands": (1, 0),
            "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_NARROW_INT_FP,
        },
        "avg_pool2d": {
            "op": Op.AVG_POOL2D,
            "operands": (1, 0),
            "rank": (4, 4),
            "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
            "qgen": TosaQuantGen.qgUnary,
            "types": TYPE_NARROW_INT_FP,
        },
        # Templated operator. Filled in by createDynamicOpLists
        "conv2d_TEMPLATE": {
            "op": Op.CONV2D,
            "operands": (1, 2),
            "rank": (4, 4),
            "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV2D,
            "template": True,
        },
        # Conv3d TBD
        # Templated operator. Filled in by createDynamicOpLists
        "depthwise_conv2d_TEMPLATE": {
            "op": Op.DEPTHWISE_CONV2D,
            "operands": (1, 2),
            "filter": [1, 1],
            "rank": (4, 4),
            "build_fcn": (
                build_depthwise_conv2d,
                TosaTensorGen.tgDepthwiseConv2D,
                TosaArgGen.agConv2D,
            ),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV2D,
            "template": True,
        },
        "fully_connected": {
            "op": Op.FULLY_CONNECTED,
            "operands": (1, 2),
            "rank": (2, 2),
            "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV2D,
        },
        "matmul": {
            "op": Op.MATMUL,
            "operands": (2, 0),
            "rank": (3, 3),
            "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
            "qgen": TosaQuantGen.qgMatmul,
            "types": TYPE_NARROW_INT_FP,
        },
        "max_pool2d": {
            "op": Op.MAX_POOL2D,
            "operands": (1, 0),
            "rank": (4, 4),
            "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
            "types": TYPE_NARROW_INT_FP,
        },
        # Templated operator. Filled in by createDynamicOpLists
        "transpose_conv2d_TEMPLATE": {
            "op": Op.TRANSPOSE_CONV2D,
            "operands": (1, 2),
            "rank": (4, 4),
            "build_fcn": (
                build_transpose_conv2d,
                TosaTensorGen.tgTransposeConv2D,
                TosaArgGen.agTransposeConv2D,
            ),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV2D,
            "template": True,
        },
        # Activation functions
        "clamp": {
            "op": Op.CLAMP,
            "operands": (1, 0),
            "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
            "types": TYPE_NARROW_INT_FP,
        },
        "relun": {
            "op": Op.RELUN,
            "operands": (1, 0),
            "build_fcn": (build_relun, TosaTensorGen.tgBasic, None),
            "types": TYPE_FI32,
        },
        "sigmoid": {
            "op": Op.SIGMOID,
            "operands": (1, 0),
            "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "tanh": {
            "op": Op.TANH,
            "operands": (1, 0),
            "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        # Elementwise Binary Operators
        "add": {
            "op": Op.ADD,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "arithmetic_right_shift": {
            "op": Op.ARITHMETIC_RIGHT_SHIFT,
            "operands": (2, 0),
            "build_fcn": (
                build_arithmetic_right_shift,
                TosaTensorGen.tgBroadcastFuzz,
                TosaArgGen.agArithmeticRightShift,
            ),
            "types": TYPE_INT,
        },
        "bitwise_and": {
            "op": Op.BITWISE_AND,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
        },
        "bitwise_or": {
            "op": Op.BITWISE_OR,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
        },
        "bitwise_xor": {
            "op": Op.BITWISE_XOR,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
        },
        "div": {
            "op": Op.DIV,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": [DType.INT32],
        },
        "logical_and": {
            "op": Op.LOGICAL_AND,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_BOOL,
        },
        "logical_left_shift": {
            "op": Op.LOGICAL_LEFT_SHIFT,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
        },
        "logical_right_shift": {
            "op": Op.LOGICAL_RIGHT_SHIFT,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
        },
        "logical_or": {
            "op": Op.LOGICAL_OR,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_BOOL,
        },
        "logical_xor": {
            "op": Op.LOGICAL_XOR,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_BOOL,
        },
        "maximum": {
            "op": Op.MAXIMUM,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "minimum": {
            "op": Op.MINIMUM,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "mul": {
            "op": Op.MUL,
            "operands": (2, 0),
            "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
            "types": TYPE_INT_FP,
        },
        "pow": {
            "op": Op.POW,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "sub": {
            "op": Op.SUB,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "table": {
            "op": Op.TABLE,
            # Use the automatic generation functions to create the input array
            # but create the table tensor in the build function, as it may be
            # a different type from the input
            "operands": (1, 0),
            "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
            "types": [DType.INT16],
        },
        # Elementwise Unary operators
        "abs": {
            "op": Op.ABS,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FI32,
        },
        "bitwise_not": {
            "op": Op.BITWISE_NOT,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_INT,
        },
        "ceil": {
            "op": Op.CEIL,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "clz": {
            "op": Op.CLZ,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": [DType.INT32],
        },
        "exp": {
            "op": Op.EXP,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "floor": {
            "op": Op.FLOOR,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "log": {
            "op": Op.LOG,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "logical_not": {
            "op": Op.LOGICAL_NOT,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_BOOL,
        },
        "negate": {
            "op": Op.NEGATE,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "qgen": TosaQuantGen.qgUnary,
            "types": TYPE_INT_FP,
        },
        "reciprocal": {
            "op": Op.RECIPROCAL,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "rsqrt": {
            "op": Op.RSQRT,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        # Elementwise Ternary operators
        "select": {
            "op": Op.SELECT,
            "operands": (3, 0),
            "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FIB,
        },
        # Comparison operators
        "equal": {
            "op": Op.EQUAL,
            "operands": (2, 0),
            "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "greater_equal": {
            "op": Op.GREATER_EQUAL,
            "operands": (2, 0),
            "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "greater": {
            "op": Op.GREATER,
            "operands": (2, 0),
            "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        # Reduction operators
        "reduce_all": {
            "op": Op.REDUCE_ALL,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_BOOL,
        },
        "reduce_any": {
            "op": Op.REDUCE_ANY,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_BOOL,
        },
        "reduce_max": {
            "op": Op.REDUCE_MAX,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_INT_FP,
        },
        "reduce_min": {
            "op": Op.REDUCE_MIN,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_INT_FP,
        },
        "reduce_product": {
            "op": Op.REDUCE_PRODUCT,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_FP,
        },
        "reduce_sum": {
            "op": Op.REDUCE_SUM,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_FI32,
        },
        # Data layout operators
        "concat": {
            "op": Op.CONCAT,
            "operands": (2, 0),
            "build_fcn": (build_concat, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_FIB,
        },
        "pad": {
            "op": Op.PAD,
            "operands": (1, 0),
            "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
            "qgen": TosaQuantGen.qgPad,
            "types": TYPE_FIB,
        },
        "reshape": {
            "op": Op.RESHAPE,
            "operands": (1, 0),
            "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
            "types": TYPE_FIB,
        },
        "reverse": {
            "op": Op.REVERSE,
            "operands": (1, 0),
            "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_FIB,
        },
        "slice": {
            "op": Op.SLICE,
            "operands": (1, 0),
            "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
            "types": TYPE_FIB,
        },
        "tile": {
            "op": Op.TILE,
            "operands": (1, 0),
            "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
            "types": TYPE_FIB,
        },
        "transpose": {
            "op": Op.TRANSPOSE,
            "operands": (1, 0),
            "rank": (2, 4),  # Do not allow transpose on rank=1
            "build_fcn": (
                build_transpose,
                TosaTensorGen.tgBasic,
                TosaArgGen.agTranspose,
            ),
            "types": TYPE_FIB,
        },
        # Data nodes
        "const": {
            "op": Op.CONST,
            "operands": (1, 0),
            "build_fcn": (build_placeholder, TosaTensorGen.tgBasic, None),
            "types": TYPE_FIB,
        },
        "identity": {
            "op": Op.IDENTITY,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FIB,
        },
        # Scatter/Gather
        "gather": {
            "op": Op.GATHER,
            # Only specify 'values' tensor here. 'indices' is generated in op building stage
            "operands": (1, 0),
            "rank": (3, 3),
            "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
            "types": TYPE_INT_FP,
        },
        "scatter": {
            "op": Op.SCATTER,
            # Only specify 'values_in' tensor here.
            # 'indices' and 'input' are generated in op building stage
            "operands": (2, 0),
            "rank": (3, 3),
            "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
            "types": TYPE_INT_FP,
        },
        # Image operations
        "resize": {
            "op": Op.RESIZE,
            "operands": (1, 0),
            "rank": (4, 4),
            "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
            "types": [DType.INT8, DType.INT16, DType.FLOAT],
        },
        # Type conversion
        "cast": {
            "op": Op.CAST,
            "operands": (1, 0),
            "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
            "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
        },
        "rescale": {
            "op": Op.RESCALE,
            "operands": (1, 0),
            "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
            "types": [DType.INT8, DType.INT16, DType.INT32, DType.INT48],
        },
        # Custom
        # Not implemented.
        # Control flow operators
        # Two variants of cond_if: one that generates one of two constant tensors (no
        # inputs to the basic blocks, one output) and another that either adds or
        # subtracts two tensors (two inputs to the basic blocks, one output)
        "cond_if_const": {
            "op": Op.COND_IF,
            "operands": (0, 2),
            "build_fcn": (
                build_cond_if_const,
                TosaTensorGen.tgBasic,
                TosaArgGen.agCondIf,
            ),
            "types": [DType.BOOL],
        },
        "cond_if_binary": {
            "op": Op.COND_IF,
            "operands": (2, 0),
            "build_fcn": (
                build_cond_if_binary,
                TosaTensorGen.tgBasic,
                TosaArgGen.agCondIf,
            ),
            "types": TYPE_FI32,
        },
        # while_loop
        "while_loop": {
            "op": Op.WHILE_LOOP,
            "operands": (0, 1),
            "build_fcn": (
                build_while_loop,
                TosaTensorGen.tgBasic,
                TosaArgGen.agWhileLoop,
            ),
            "types": [DType.INT32],
        },
    }


class OutputShaper:
    # Methods in this class compute the expected output shape and datatype
    # for common classes of operations
    def __init__(self):
        pass

    # These methods return arguments that can be used for
    # creating a new output tensor
    @staticmethod
    def binaryBroadcastOp(ser, a, b):
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        shape = []
        for i in range(len(a.shape)):
            if a.shape[i] == 1:
                shape.append(b.shape[i])
            else:
                shape.append(a.shape[i])

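        # e.g. a.shape == [1, 3], b.shape == [4, 1] -> output shape [4, 3]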
        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def binaryNonBroadcastOp(ser, a, b):
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        shape = []
        for i in range(len(a.shape)):
            assert a.shape[i] == b.shape[i]
            shape.append(a.shape[i])

        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def unaryOp(ser, a):
        return ser.addOutput(a.shape, a.dtype)

    @staticmethod
    def selectOp(ser, cond, a, b):
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
        assert a.dtype == b.dtype

        shape = []
        for i in range(len(a.shape)):
            shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))

        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def binaryComparisonOp(ser, a, b):
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        # Do broadcast
        shape = []
        for i in range(len(a.shape)):
            if a.shape[i] == 1:
                shape.append(b.shape[i])
            else:
                shape.append(a.shape[i])

        # Force the output type to bool
        return ser.addOutput(shape, DType.BOOL)

    @staticmethod
    def reduceOp(ser, a, axis):

        shape = a.shape.copy()

        shape[axis] = 1

        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def argmaxOp(ser, a, axis):
        shape = a.shape.copy()
        del shape[axis]
        return ser.addOutput(shape, DType.INT32)

    @staticmethod
    def conv2dOp(ser, ifm, filter, strides, padding, dilations):

        # IFM: NHWC
        # Filter: OHWI
        # OFM: NHWC

        if len(padding) == 2:
            # Expand padding to 4 parameters in the case of transpose_conv2d
            # From H,W to T,B,L,R
            padding = [padding[0], padding[0], padding[1], padding[1]]

        h = (
            ifm.shape[1]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        w = (
            ifm.shape[2]
            - filter.shape[2]
            - (filter.shape[2] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

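        # e.g. ifm 1x8x8x4 (NHWC), filter 8x3x3x4 (OHWI), strides [1, 1],
        # dilations [1, 1], padding [1, 1, 1, 1]:
        # h = w = (8 - 3 - 0 + 1 + 1) // 1 + 1 = 8, i.e. a 1x8x8x8 OFM
        # ("same" padding)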
        if h <= 0 or w <= 0:
            # Invalid test parameters?
            h = 0
            w = 0
            ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")

        ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        return ser.addOutput(ofm_shape, out_dtype)

    @staticmethod
    def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
        # IFM: NHWC
        # Filter: HWCM
        # OFM: NHW C*M
        h = (
            ifm.shape[1]
            - filter.shape[0]
            - (filter.shape[0] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        w = (
            ifm.shape[2]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        if h <= 0 or w <= 0:
            # Invalid test parameters?
            h = 0
            w = 0
            ser.setExpectedFailure(
                True, "Invalid combination of depthwise_conv2d parameters"
            )

        ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        return ser.addOutput(ofm_shape, out_dtype)

    @staticmethod
    def pool2dOp(ser, ifm, kernel, stride, pad):
        # input: NHWC
        h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
        w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]

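        # e.g. ifm 1x8x8x3, kernel [2, 2], stride [2, 2], pad [0, 0, 0, 0]:
        # h = w = (8 + 0 + 0 + 2 - 2) // 2 = 4, i.e. a 1x4x4x3 output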
        if h <= 0 or w <= 0:
            # Invalid test parameters?
            h = 0
            w = 0
            ser.setExpectedFailure(True, "Invalid combination of pooling parameters")

        ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
        return ser.addOutput(ofm_shape, ifm.dtype)

    @staticmethod
    def fullyConnectedOp(ser, input, filter):
        # input: N, IC
        # filter: OC, IC
        # output: N, OC

        output_shape = [input.shape[0], filter.shape[0]]

        if input.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif input.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif input.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(input.dtype))

        return ser.addOutput(output_shape, out_dtype)

    @staticmethod
    def matmulOp(ser, a, b):
        # a: N, H, C
        # b: N, C, W
        # out: N, H, W

        output_shape = [a.shape[0], a.shape[1], b.shape[2]]

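        # e.g. a: 2x3x4, b: 2x4x5 -> output_shape [2, 3, 5]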
        if a.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif a.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif a.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))

        return ser.addOutput(output_shape, out_dtype)

    @staticmethod
    def concatOp(ser, a, b, axis):

        output_shape = a.shape.copy()
        output_shape[axis] = a.shape[axis] + b.shape[axis]

        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def padOp(ser, a, padding):

        output_shape = a.shape.copy()

        for i in range(len(output_shape)):
            output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]

        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def reshapeOp(ser, a, shape):
        output_shape = shape.copy()

        totalElements = 1
        for i in a.shape:
            totalElements *= i

        # If there are any -1 elements, figure out what that dimension must be
        totalOutputElements = 1
        for i in output_shape:
            if i != -1:
                totalOutputElements *= i

        # And fill it in
        for i in range(len(output_shape)):
            if output_shape[i] == -1:
                output_shape[i] = totalElements // totalOutputElements
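        # e.g. a.shape == [2, 3, 4] (24 elements) reshaped to [2, -1]: the -1
        # resolves to 24 // 2 == 12, giving an output shape of [2, 12]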

        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def sliceOp(ser, a, begin, size):

        output_shape = size.copy()
        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def tileOp(ser, a, multiples):

        output_shape = a.shape.copy()
        assert len(multiples) == len(output_shape)

        for i in range(len(output_shape)):
            output_shape[i] = a.shape[i] * multiples[i]

        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def transposeOp(ser, a, perms):
        output_shape = a.shape.copy()
        assert len(perms) == len(output_shape)

        for i in range(len(output_shape)):
            output_shape[i] = a.shape[perms[i]]

        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def gatherOp(ser, values, indices):
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

        output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

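        # e.g. values: 2x6x4 (N, K, C), indices: 2x3 (N, W) -> output 2x3x4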
        return ser.addOutput(output_shape, values.dtype)

    @staticmethod
    def scatterOp(ser, values_in, indices, input):
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

        output_shape = values_in.shape

        return ser.addOutput(output_shape, values_in.dtype)

    @staticmethod
    def tableOp(ser, input, table):
        # Same shape as the input; the output type here is always INT32
        return ser.addOutput(input.shape, DType.INT32)

    @staticmethod
    def resizeOp(
        ser,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
    ):

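        # Expected dtype pairings, as checked below: BILINEAR maps int8 -> int32,
        # int16 -> int48 and float -> float; NEAREST maps int8 -> int8,
        # int16 -> int16 and float -> float.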
        output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]

        if input_dtype == DType.FLOAT:
            if stride_fp[0] <= 0 or stride_fp[1] <= 0:
                ser.setExpectedFailure(True, "Negative or zero stride")
        else:
            if stride[0] <= 0 or stride[1] <= 0:
                ser.setExpectedFailure(True, "Negative or zero stride")

        if mode == ResizeMode.BILINEAR:
            if input_dtype == DType.INT8:
                if output_dtype != DType.INT32:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.INT16:
                if output_dtype != DType.INT48:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.FLOAT:
                if output_dtype != DType.FLOAT:
                    ser.setExpectedFailure(True, "Invalid output data type")
            else:
                ser.setExpectedFailure(True, "Invalid input data type")

        elif mode == ResizeMode.NEAREST:
            if input_dtype == DType.INT8:
                if output_dtype != DType.INT8:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.INT16:
                if output_dtype != DType.INT16:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.FLOAT:
                if output_dtype != DType.FLOAT:
                    ser.setExpectedFailure(True, "Invalid output data type")
            else:
                ser.setExpectedFailure(True, "Invalid input data type")

        else:
            ser.setExpectedFailure(True, "Invalid resize mode")

        return ser.addOutput(output_dims, output_dtype)

    @staticmethod
    def typeConversionOp(ser, val, out_dtype):
        return ser.addOutput(val.shape, out_dtype)

    @staticmethod
    def transposeConv2DOp(ser, ifm, output_shape):
        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        if output_shape[1] <= 0 or output_shape[2] <= 0:
            ser.setExpectedFailure(True, "Negative output shape")

        return ser.addOutput(output_shape, out_dtype)