blob: 7731a75decf0edaf613c79e57074c6c799b14c77 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
import numpy as np
import argparse
import sys
import re
import os
import subprocess
import shlex
import json
import glob
import math
import queue
import threading
import traceback

from enum import IntEnum, Enum, unique

# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(
    os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
)
import tosa_serializer as ts
from tosa_serializer import *
import tosa

# Convenience variables to the flatc-generated types that should be enums, but aren't
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()
48
Kevin Cheng550ccc52021-03-03 11:21:43 -080049
class TosaQuantGen:
    """Helpers that build randomized QuantizedInfo records.

    An operator definition selects one of these with its 'qgen': entry."""

    def __init__(self):
        pass

    @staticmethod
    def needsQinfo(op, dtype):
        # Zero points are only meaningful for the narrow integer types.
        return dtype in (DType.INT8, DType.INT16)

    @staticmethod
    def qgUnary(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        zps = (
            (testGen.randInt(), testGen.randInt())
            if TosaQuantGen.needsQinfo(op, dtype)
            else (0, 0)
        )
        qinfo.UnaryQuantInfo(zps[0], zps[1])
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        zps = (
            (testGen.randInt(), testGen.randInt())
            if TosaQuantGen.needsQinfo(op, dtype)
            else (0, 0)
        )
        qinfo.ConvQuantInfo(zps[0], zps[1])
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        zps = (
            (testGen.randInt(), testGen.randInt())
            if TosaQuantGen.needsQinfo(op, dtype)
            else (0, 0)
        )
        qinfo.MatMulQuantInfo(zps[0], zps[1])
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        zp = testGen.randInt() if TosaQuantGen.needsQinfo(op, dtype) else 0
        qinfo.PadQuantInfo(zp)
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32: convert a
        # floating-point scale into a fixed-point (multiplier, shift) pair.
        # scale32 selects a 32-bit (31 fractional bits) or 16-bit (15 bits)
        # multiplier.
        scaleBits = 31 if scale32 else 15

        mantissa, shift = math.frexp(scaleFp)
        if scaleFp < 0.0:
            mantissa = -mantissa

        multiplier = round(mantissa * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        # frexp gives mantissa in [0.5, 1.0); rounding can still land exactly
        # on 1 << scaleBits, in which case halve and bump the exponent.
        if multiplier == (1 << scaleBits):
            multiplier //= 2
            shift += 1

        shift = scaleBits - shift

        assert multiplier <= (1 << scaleBits)
        assert 0 <= shift <= 63

        return multiplier, shift
128
129
class TosaTensorGen:
    """Tensor generators produce the list of operand shapes (placeholder and
    const inputs) for an operator.  Random data for those tensors is generated
    separately, per test."""

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank):
        pl, const = opName["operands"]
        base = testGen.makeShape(rank)

        # All operands share one shape; hand out an independent copy to each.
        return [base.copy() for _ in range(pl + const)]

    @staticmethod
    def tgNHWC(testGen, opName, rank):
        pl, const = opName["operands"]

        assert rank == 4

        base = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            base[0] = (base[0] % testGen.args.max_batch_size) + 1

        return [base.copy() for _ in range(pl + const)]

    @staticmethod
    def tgScatter(testGen, opName, rank):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        # Random width for the second (indices/input) operand.
        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        return [values_in_shape.copy(), input_shape.copy()]

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank):
        base = testGen.makeShape(rank)

        pl, const = op["operands"]

        # Pick one operand to broadcast, then flatten one random dimension
        # of that operand to 1.
        bcast_idx = testGen.randInt(0, pl + const)

        shapes = []
        for idx in range(pl + const):
            candidate = base.copy()
            if idx == bcast_idx:
                candidate[testGen.randInt(0, rank)] = 1
            shapes.append(candidate)

        return shapes

    @staticmethod
    def tgConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Filter height/width come from the operator parameters.
        kh, kw = op["filter"][0], op["filter"][1]

        # Random OFM depth.
        ofm_depth = testGen.makeShape(1)[0]

        # Filter layout is OHWI; bias is one entry per output channel.
        return [
            ifm_shape,
            np.asarray([ofm_depth, kh, kw, ifm_shape[3]]),
            np.asarray([ofm_depth]),
        ]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Filter height/width come from the operator parameters.
        kh, kw = op["filter"][0], op["filter"][1]

        # Random OFM depth.
        ofm_depth = testGen.makeShape(1)[0]

        # Filter layout is OHWI; bias is one entry per output channel.
        return [
            ifm_shape,
            np.asarray([ofm_depth, kh, kw, ifm_shape[3]]),
            np.asarray([ofm_depth]),
        ]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Filter height/width come from the operator parameters; the filter
        # layout is H, W, C, M.
        kh, kw = op["filter"][0], op["filter"][1]

        # Keep the channel multiplier M small: the output depth is M * C.
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # Filter is HWCM; bias has M * C entries.
        return [
            ifm_shape,
            np.asarray([kh, kw, ifm_shape[3], filter_m]),
            np.asarray([ifm_shape[3] * filter_m]),
        ]

    @staticmethod
    def tgFullyConnected(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        filter_oc = testGen.makeShape(1)[0]

        # Weight is [OC, IC]; bias is one entry per output channel.
        return [
            input_shape,
            np.asarray([filter_oc, input_shape[1]]),
            np.asarray([filter_oc]),
        ]

    @staticmethod
    def tgMatmul(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 2
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        b_oc = testGen.makeShape(1)[0]

        # B's leading dim must match A's trailing dim.
        return [a_shape, np.asarray([a_shape[1], b_oc])]
325
Kevin Cheng550ccc52021-03-03 11:21:43 -0800326
class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for operators that take
    attributes or other parameters.  The return value is a list of (descriptive_name, [arglist])
    tuples where the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function."""

    def __init__(self):
        pass

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype):
        """Build the axis argument for operators that take a single axis"""
        axes = []

        shape = shapeList[0]

        for a in range(0, len(shape)):
            axes.append(("axis_{}".format(a), [a]))
        return axes

    @staticmethod
    def agConv2D(testGen, opName, shapeList, dtype):
        """Sweep stride / padding / dilation combinations for CONV2D."""
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        maxStride = testGen.args.max_conv_stride
        maxPadding = testGen.args.max_conv_padding + 1
        maxDilation = testGen.args.max_conv_dilation

        # Each loop index is decoded as a mixed-radix number: two
        # base-maxStride / base-maxDilation digits, four base-maxPadding
        # digits.
        for stride in range(0, maxStride ** 2):
            for padding in range(0, (maxPadding) ** 4):
                for dilation in range(0, maxDilation ** 2):

                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    # BUG FIX: digit extraction must divide by powers of
                    # maxPadding (** 3, ** 2), not multiples (* 4, * 2).  The
                    # old form only enumerated every combination when
                    # maxPadding == 2 (where 2*4 == 2**3 coincidentally);
                    # for larger maxPadding it produced duplicates and missed
                    # combinations.
                    p = [
                        (padding // (maxPadding ** 3)) % maxPadding,
                        (padding // (maxPadding ** 2)) % maxPadding,
                        (padding // (maxPadding ** 1)) % maxPadding,
                        padding % maxPadding,
                    ]
                    d = [dilation // maxDilation + 1, dilation % maxDilation + 1]

                    # 4 padding parameters for regular conv2d
                    arg_list.append(
                        (
                            "st{}{}_pad{}{}{}{}_dilat{}{}".format(
                                s[0], s[1], p[0], p[1], p[2], p[3], d[0], d[1]
                            ),
                            [s, p, d],
                        )
                    )
        return arg_list

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtype):
        """Sweep stride / output padding / dilation for TRANSPOSE_CONV2D and
        derive the output shape for each combination."""
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        maxStride = testGen.args.max_conv_stride
        maxPadding = testGen.args.max_conv_padding + 1
        maxDilation = testGen.args.max_conv_dilation

        # Strides, padding, dilations
        for stride in range(0, maxStride ** 2):
            for out_padding in range(0, (maxPadding) ** 2):
                for dilation in range(0, maxDilation ** 2):

                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    # Two base-maxPadding digits of output padding
                    p = [
                        (out_padding // (maxPadding ** 1)) % maxPadding,
                        out_padding % maxPadding,
                    ]
                    d = [dilation // maxDilation + 1, dilation % maxDilation + 1]

                    # Transpose-conv output height/width for this combination
                    oh = (
                        ifm_shape[1]
                        - filter_shape[1]
                        - (filter_shape[1] - 1) * (d[0] - 1)
                        + 2 * p[0]
                    ) // s[0] + 1

                    ow = (
                        ifm_shape[2]
                        - filter_shape[2]
                        - (filter_shape[2] - 1) * (d[1] - 1)
                        + 2 * p[1]
                    ) // s[1] + 1

                    # Output shape
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]

                    arg_list.append(
                        (
                            "st{}{}_outpad{}{}_dilat{}{}_os{}x{}x{}x{}".format(
                                s[0],
                                s[1],
                                p[0],
                                p[1],
                                d[0],
                                d[1],
                                os[0],
                                os[1],
                                os[2],
                                os[3],
                            ),
                            [s, p, d, os],
                        )
                    )

        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype):
        """Enumerate 0/1 padding combinations on each side of each dimension."""
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of 0/1 padding on each side of each
        # dimension, treating v as a bitmask over the rank*2 pad entries.
        # NOTE(review): a complete sweep needs 2 ** (rank * 2) masks; the
        # original "for now" choice of rank ** 2 is kept to bound the number
        # of generated tests.
        for v in range(rank ** 2):

            # Create a flat padding array
            paddings = np.zeros((rank * 2), dtype=np.int32)

            # Fill in the 1's
            for r in range(rank * 2):
                if (v >> r) & 1:
                    paddings[r] = 1

            # Reshape back to the [rank, 2] layout the operator expects
            paddings = paddings.reshape((rank, 2))

            arg_list.append(("pad{0:b}".format(v), [paddings]))

        return arg_list

    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype):
        """Sweep kernel / stride / padding combinations for the pooling ops."""
        arg_list = []

        shape = shapeList[0]
        assert len(shape) == 4

        maxStride = testGen.args.max_pooling_stride
        maxKernel = testGen.args.max_pooling_kernel
        maxPadding = testGen.args.max_pooling_padding + 1

        for kernel in range(0, maxKernel ** 2):
            for stride in range(0, maxStride ** 2):
                for padding in range(0, maxPadding ** 4):
                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    # Minimum kernel dimension is 2
                    k = [(kernel // maxKernel) + 2, (kernel % maxKernel) + 2]
                    # BUG FIX: decode the four base-maxPadding digits with
                    # powers (** 3, ** 2), not multiples (* 4, * 2); the old
                    # form was only correct for maxPadding == 2 (see agConv2D).
                    p = [
                        (padding // (maxPadding ** 3)) % maxPadding,
                        (padding // (maxPadding ** 2)) % maxPadding,
                        (padding // (maxPadding ** 1)) % maxPadding,
                        padding % maxPadding,
                    ]

                    arg_list.append(
                        (
                            "st{}{}_kern{}{}_pad{}{}{}{}".format(
                                s[0], s[1], k[0], k[1], p[0], p[1], p[2], p[3]
                            ),
                            [k, s, p],
                        )
                    )
        return arg_list

    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype):
        """Enumerate the legal CAST output types for the given input type."""
        arg_list = []

        # Enumerate the output types here
        if inDtype == DType.INT8:
            dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT16:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT32:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FLOAT:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))

        return arg_list

    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype):
        """Sweep RESCALE output type and the scale32 / double_round /
        per_channel flag combinations."""
        arg_list = []

        # Enumerate the output types here
        for dtype in [DType.INT8, DType.INT16, DType.INT32]:
            for scale32 in [False, True]:
                for double_round in [False, True]:
                    for per_channel in [False, True]:

                        if inDtype == DType.INT48 and scale32:
                            # Illegal condition.  Must be scale32=False
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[dtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [dtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

    @staticmethod
    def agMul(testGen, opName, shapeList, dtype):
        """MUL takes a result shift: randomize it for INT32, fix it at 0
        for every other type."""
        arg_list = []

        # BUG FIX: compare with '==' rather than 'is' -- the DType constants
        # are plain ints, so identity comparison relies on CPython's small-int
        # caching and is not a reliable equality test.
        if dtype == DType.INT32:
            for p in range(testGen.args.num_rand_permutations):

                shift = testGen.randInt(0, 32)

                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("shift0", [0]))

        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype):
        """Test both values of the 'round' attribute."""
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    # Helper function for reshape.  Gets some factors of a larger number.
    @staticmethod
    def getFactors(val, start=1):
        """Return the factors of val in [start, sqrt(val)).

        NOTE(review): deliberately partial -- cofactors at or above sqrt(val)
        are recovered implicitly through agReshape's remainder term.
        """
        factors = []

        for i in range(start, int(np.sqrt(val))):
            if (val % i) == 0:
                factors.append(i)

        return factors

    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype):
        """Generate random new shapes with the same total element count."""
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast.  Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 6)
            newShape = []
            if len(factors) < newRank:
                continue

            remainingElements = totalElements
            shuffledFactors = testGen.rng.permutation(factors)
            # BUG FIX: pick newRank - 1 factors (loop was range(newRank),
            # which together with the remainder term below produced shapes of
            # rank newRank + 1 -- contradicting both the 'pick rank-1 factors'
            # intent and the 'rank{}' test name).
            for i in range(1, newRank):
                # pick rank-1 factors
                newShape.append(shuffledFactors[0])
                remainingElements = remainingElements // shuffledFactors[0]
                shuffledFactors = testGen.rng.permutation(
                    TosaArgGen.getFactors(remainingElements)
                )
            newShape.append(remainingElements)

            # Toss in a -1 sometimes
            minusOne = testGen.randInt(0, newRank * 4)
            if minusOne < newRank:
                newShape[minusOne] = -1

            arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list

    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype):
        """Generate random axis permutations, skipping duplicates."""
        arg_list = []

        ifm_shape = shapeList[0]

        perms = range(len(ifm_shape))
        for p in range(testGen.args.num_rand_permutations):
            perms = np.int32(testGen.rng.permutation(perms)).tolist()

            # Avoid duplicates
            found = any(other_perm[0] == perms for _, other_perm in arg_list)

            if not found:
                arg_list.append(("perm{}".format(p), [perms]))

        return arg_list

    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype):
        """Generate random begin/size pairs; combinations that hit a
        zero-sized slice are discarded."""
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            begin = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    begin.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    begin.append(0)
                    size.append(1)

            if valid:
                arg_list.append(("perm{}".format(p), [begin, size]))
        return arg_list

    @staticmethod
    def agTile(testGen, opName, shapeList, dtype):
        """Pick small random per-dimension multiples (large multiples have a
        tendency to generate enormous tensors)."""
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random, but small multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                multiples.append(testGen.randInt(1, 4))

            arg_list.append(("perm{}".format(p), [multiples]))

        return arg_list

    @staticmethod
    def agResize(testGen, opName, shapeList, dtype):
        """Generate random output dims for RESIZE and derive the matching
        stride/offset (fixed point for integer types, float otherwise)."""
        arg_list = []

        ifm_shape = shapeList[0]

        for m in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations.  Pick legal output types
            if m == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif m == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):

                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    output_dims = [testGen.randInt(1), testGen.randInt(1)]
                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        # Float resize carries the stride/offset directly.
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]
                        arg_list.append(
                            (
                                "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                    m,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride_fp[0],
                                    stride_fp[1],
                                    offset_fp[0],
                                    offset_fp[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )
                    else:
                        # Integer resize: quantize stride/offset at the
                        # largest shift whose values fit in int16.
                        shift = 11
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))

                        while (
                            stride_y >= 32768
                            or stride_x >= 32768
                            or offset_y >= 32768
                            or offset_x >= 32768
                            or offset_y < -32768
                            or offset_x < -32768
                        ):
                            shift = shift - 1
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                        arg_list.append(
                            (
                                "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                    m,
                                    shift,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride[0],
                                    stride[1],
                                    offset[0],
                                    offset[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )

        return arg_list

    # CONSISTENCY FIX: @staticmethod added to the two generators below to
    # match every sibling; calls through the class attribute are unchanged.
    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype):
        # While loop: 0 iterations, 1, more than 1
        arg_list = []

        for iter in [0, 1, 4]:
            arg_list.append(("iter{}".format(iter), [iter]))

        return arg_list
857
Kevin Cheng550ccc52021-03-03 11:21:43 -0800858
Eric Kunzee5e26762020-10-13 16:11:07 -0700859class TosaTestGen:
    def __init__(self, args):
        """Capture the parsed CLI args and seed all generator state.

        args: an argparse-style namespace; output_dir, random_seed and the
        various max_* / *_range limits are read by the generators.
        """
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        # Serializer is created per test by createSerializer()
        self.ser = None
        # Single RNG, seeded once, shared by all shape/data/argument generation
        self.rng = np.random.default_rng(self.random_seed)
        # NOTE(review): both methods are defined elsewhere in the class --
        # presumably they build the operator tables; confirm ordering matters.
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None
871
872 def createSerializer(self, opName, testPath):
873 self.testPath = os.path.join(opName, testPath)
874
875 fullPath = os.path.join(self.basePath, self.testPath)
876 os.makedirs(fullPath, exist_ok=True)
877 self.ser = ts.TosaSerializer(fullPath)
878
    def getSerializer(self):
        """Return the current serializer (set by createSerializer; None before that)."""
        return self.ser
881
    def serialize(self, testName):
        """Write the serialized test to <testName>.tosa plus a desc.json
        descriptor, both under the current test directory (basePath/testPath)."""
        with open(
            os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
        ) as fd:
            fd.write(self.ser.serialize())

        # The JSON descriptor references the .tosa file by name
        with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
            fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -0700890
891 def getRandTensor(self, shape, dtype):
892 RAND_SHIFT_FACTOR = 0.5
893 RAND_SCALE_FACTOR = 4.0
894
895 if dtype == DType.BOOL:
896 np_dt = np.bool
897 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700898 elif dtype == DType.INT4:
899 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
900 elif dtype == DType.INT8:
901 return np.int32(self.rng.integers(low=-127, high=128, size=shape))
902 elif dtype == DType.INT16:
903 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
904 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800905 return np.int32(
906 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
907 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700908 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800909 return np.int64(
910 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
911 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700912 elif dtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800913 return np.float32(
914 self.rng.random(size=shape) - RAND_SHIFT_FACTOR * RAND_SCALE_FACTOR
915 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700916 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800917 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700918
Kevin Cheng989cb052021-04-28 16:29:44 -0700919 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700920 placeholders = []
921
Kevin Cheng989cb052021-04-28 16:29:44 -0700922 assert len(shape_list) == len(dtype_list)
923
924 for idx, shape in enumerate(shape_list):
925 arr = self.getRandTensor(shape, dtype_list[idx])
926 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700927
928 return placeholders
929
Kevin Cheng989cb052021-04-28 16:29:44 -0700930 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700931 consts = []
932
Kevin Cheng989cb052021-04-28 16:29:44 -0700933 assert len(shape_list) == len(dtype_list)
934
935 for idx, shape in enumerate(shape_list):
936 arr = self.getRandTensor(shape, dtype_list[idx])
937 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700938
939 return consts
940
941 def makeShape(self, rank):
942 if self.targetted_shape:
943 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800944 return np.int32(
945 self.rng.integers(
946 low=self.args.tensor_shape_range[0],
947 high=self.args.tensor_shape_range[1],
948 size=rank,
949 )
950 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700951
952 def setTargetShape(self, shape):
953 self.targetted_shape = shape
954
955 def randInt(self, low=0, high=256):
956 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
957
958 def getRandNumberDType(self, dtype):
959 if dtype == DType.FLOAT:
960 return self.rng.random()
961 elif dtype == DType.BOOL:
962 return self.rng.choice([False, True])
963 elif dtype == DType.INT4:
964 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700965 elif dtype == DType.INT8:
966 low, high = (-127, 128)
967 elif dtype == DType.INT16:
968 low, high = (-32768, 32768)
969 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800970 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -0700971 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800972 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -0700973 # Special size
974 return np.int64(self.rng.integers(low, high, size=1))[0]
975 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800976 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700977
978 return np.int32(self.rng.integers(low, high, size=1))[0]
979
980 def shapeStr(self, shape):
981
982 sStr = []
983 # Convert to strings
984 for i in shape:
985 sStr.append(str(i))
986
Kevin Cheng550ccc52021-03-03 11:21:43 -0800987 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700988
989 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -0700990 if isinstance(t, list):
991 assert len(t) >= 2
992 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700993 else:
Kevin Cheng989cb052021-04-28 16:29:44 -0700994 if t == DType.BOOL:
995 return "b"
996 elif t == DType.INT4:
997 return "i4"
998 elif t == DType.INT8:
999 return "i8"
1000 elif t == DType.UINT8:
1001 return "u8"
1002 elif t == DType.INT16:
1003 return "i16"
1004 elif t == DType.INT32:
1005 return "i32"
1006 elif t == DType.INT48:
1007 return "i48"
1008 elif t == DType.FLOAT:
1009 return "float"
1010 else:
1011 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001012
1013 def typeWidth(self, t):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001014 """ Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -08001015 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07001016 return 4
1017 elif t == DType.INT8:
1018 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08001019 elif t == DType.UINT8:
1020 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07001021 elif t == DType.INT16:
1022 return 16
1023 elif t == DType.INT32:
1024 return 32
1025 elif t == DType.INT48:
1026 return 48
1027 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001028 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001029
1030 # Argument generators
1031 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
1032 # Where the string descriptor is used to generate the test name and
1033 # The build_fcn_arg_list is expanded and passed to the operator test
1034 # build function
1035
Kevin Cheng550ccc52021-03-03 11:21:43 -08001036 def build_unary(self, op, a, qinfo=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001037 result_tens = OutputShaper.unaryOp(self.ser, a)
1038 self.ser.addOperator(op, [a.name], [result_tens.name], None, qinfo)
1039 return result_tens
1040
1041 def build_binary_broadcast(self, op, a, b):
1042 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1043 self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
1044 return result_tens
1045
1046 def build_binary_nonbroadcast(self, op, a, b):
1047 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
1048 self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
1049 return result_tens
1050
Kevin Chengaee1fac2020-11-11 13:54:06 -08001051 def build_arithmetic_right_shift(self, op, a, b, round):
1052 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1053
1054 attr = ts.TosaSerializerAttribute()
1055 attr.ArithmeticRightShiftAttribute(round)
1056
1057 self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
1058 return result_tens
1059
1060 def build_mul(self, op, a, b, shift):
Eric Kunzee5e26762020-10-13 16:11:07 -07001061 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1062
1063 # Special for multiply:
1064 # Force the result to INT32 for INT types
1065 if a.dtype != DType.FLOAT:
1066 result_tens.setDtype(DType.INT32)
1067
Kevin Chengaee1fac2020-11-11 13:54:06 -08001068 attr = ts.TosaSerializerAttribute()
1069 attr.MulAttribute(shift)
1070
1071 self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001072 return result_tens
1073
1074 def build_table(self, op, a):
1075 # Constant size, random values
1076 table_arr = self.getRandTensor([513], DType.INT16)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001077 table_tens = self.ser.addConst(table_arr.shape, DType.INT16, table_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001078
1079 result_tens = OutputShaper.tableOp(self.ser, a, table_tens)
1080 self.ser.addOperator(op, [a.name, table_tens.name], [result_tens.name], None)
1081
1082 return result_tens
1083
1084 def build_select(self, op, cond, a, b):
1085
1086 # Replace the cond tensor with a boolean tensor since it probably
1087 # has the wrong dtype
Kevin Cheng989cb052021-04-28 16:29:44 -07001088 t = self.buildPlaceholderTensors([cond.shape], [DType.BOOL])
Eric Kunzee5e26762020-10-13 16:11:07 -07001089 cond = t[0]
1090
1091 result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
1092 self.ser.addOperator(op, [cond.name, a.name, b.name], [result_tens.name])
1093
1094 return result_tens
1095
1096 def build_comparison(self, op, a, b):
1097 result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
1098 self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
1099 return result_tens
1100
1101 def build_argmax(self, op, a, axis):
1102 result_tens = OutputShaper.argmaxOp(self.ser, a, axis)
1103
1104 attr = ts.TosaSerializerAttribute()
1105 attr.AxisAttribute(axis)
1106
1107 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1108 return result_tens
1109
Kevin Cheng550ccc52021-03-03 11:21:43 -08001110 def build_pool2d(self, op, input, kernel, stride, pad, qinfo=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001111 result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)
1112
1113 attr = ts.TosaSerializerAttribute()
1114 attr.Pool2dAttribute(kernel, stride, pad)
Eric Kunzee5e26762020-10-13 16:11:07 -07001115
1116 self.ser.addOperator(op, [input.name], [result_tens.name], attr, qinfo)
1117 return result_tens
1118
1119 def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001120 assert len(padding) == 4
1121 result_tens = OutputShaper.conv2dOp(
1122 self.ser, ifm, filter, strides, padding, dilations
1123 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001124
1125 attr = ts.TosaSerializerAttribute()
1126 attr.Conv2dAttribute(padding, strides, dilations)
1127
Kevin Cheng550ccc52021-03-03 11:21:43 -08001128 self.ser.addOperator(
1129 op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
1130 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001131 return result_tens
1132
Kevin Cheng550ccc52021-03-03 11:21:43 -08001133 def build_transpose_conv2d(
Kevin Cheng989cb052021-04-28 16:29:44 -07001134 self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08001135 ):
1136 assert len(outpad) == 2
Eric Kunzee5e26762020-10-13 16:11:07 -07001137 result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
1138
1139 attr = ts.TosaSerializerAttribute()
1140 attr.TransposeConv2DAttribute(outpad, stride, dilation, output_shape)
1141
Kevin Cheng550ccc52021-03-03 11:21:43 -08001142 self.ser.addOperator(
Kevin Cheng989cb052021-04-28 16:29:44 -07001143 op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08001144 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001145 return result_tens
1146
Kevin Cheng550ccc52021-03-03 11:21:43 -08001147 def build_depthwise_conv2d(
1148 self, op, ifm, filter, bias, strides, padding, dilations, qinfo
1149 ):
1150 result_tens = OutputShaper.depthwiseConv2dOp(
1151 self.ser, ifm, filter, strides, padding, dilations
1152 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001153
1154 attr = ts.TosaSerializerAttribute()
1155 attr.Conv2dAttribute(padding, strides, dilations)
1156
Kevin Cheng550ccc52021-03-03 11:21:43 -08001157 self.ser.addOperator(
1158 op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
1159 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001160 return result_tens
1161
1162 def build_fully_connected(self, op, ifm, filter, bias, qinfo):
1163 result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)
1164
Kevin Cheng550ccc52021-03-03 11:21:43 -08001165 self.ser.addOperator(
1166 op, [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
1167 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001168 return result_tens
1169
1170 def build_matmul(self, op, a, b, qinfo):
1171 result_tens = OutputShaper.matmulOp(self.ser, a, b)
1172 self.ser.addOperator(op, [a.name, b.name], [result_tens.name], None, qinfo)
1173 return result_tens
1174
1175 def build_reduce(self, op, a, axis):
1176 result_tens = OutputShaper.reduceOp(self.ser, a, axis)
1177
1178 attr = ts.TosaSerializerAttribute()
1179 attr.AxisAttribute(axis)
1180
1181 self.ser.addOperator(op, [a.name], result_tens.name, attr)
1182 return result_tens
1183
1184 def build_clamp(self, op, a):
1185 result_tens = OutputShaper.unaryOp(self.ser, a)
1186
1187 attr = ts.TosaSerializerAttribute()
1188
1189 # Get two random ints
1190 v = [self.randInt(), self.randInt()]
1191
1192 if a.dtype == DType.FLOAT:
1193 attr.ClampAttribute(0, 0, min(v), max(v))
1194 else:
1195 attr.ClampAttribute(min(v), max(v), 0, 0)
1196
1197 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1198 return result_tens
1199
1200 def build_leaky_relu(self, op, a):
1201 result_tens = OutputShaper.unaryOp(self.ser, a)
1202 attr = ts.TosaSerializerAttribute()
1203
1204 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
1205
1206 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1207 return result_tens
1208
1209 # Needs an additional type/input
1210 def build_prelu(self, op, a):
1211 result_tens = OutputShaper.unaryOp(self.ser, a)
1212
1213 self.ser.addOperator(op, [a.name], [result_tens.name])
1214 return result_tens
1215
1216 def build_relun(self, op, a):
1217 result_tens = OutputShaper.unaryOp(self.ser, a)
1218
1219 attr = ts.TosaSerializerAttribute()
1220
1221 if a.dtype == DType.FLOAT:
1222 attr.ReluNAttribute(0, self.getRandNumberDType(a.dtype))
1223 else:
1224 attr.ReluNAttribute(self.getRandNumberDType(a.dtype), 0)
1225
1226 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1227 return result_tens
1228
1229 def build_sigmoid(self, op, a):
1230 result_tens = OutputShaper.unaryOp(self.ser, a)
1231 self.ser.addOperator(op, [a.name], [result_tens.name])
1232 return result_tens
1233
1234 def build_tanh(self, op, a):
1235 result_tens = OutputShaper.unaryOp(self.ser, a)
1236 self.ser.addOperator(op, [a.name], [result_tens.name])
1237 return result_tens
1238
1239 def build_concat(self, op, a, b, axis):
1240 result_tens = OutputShaper.concatOp(self.ser, a, b, axis)
1241
1242 attr = ts.TosaSerializerAttribute()
1243 attr.AxisAttribute(axis)
1244
1245 self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
1246
1247 def build_pad(self, op, a, padding, qinfo):
1248 result_tens = OutputShaper.padOp(self.ser, a, padding)
1249
1250 # Need to turn the padding array into a TOSA tensor here.
1251 # This is one of the few tensor operands that does not get
1252 # randomly generated
Kevin Cheng550ccc52021-03-03 11:21:43 -08001253 padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)
Eric Kunzee5e26762020-10-13 16:11:07 -07001254
Kevin Cheng550ccc52021-03-03 11:21:43 -08001255 self.ser.addOperator(
1256 op, [a.name, padding_tens.name], [result_tens.name], None, qinfo
1257 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001258
1259 def build_reshape(self, op, a, newShape):
1260 result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)
1261
1262 attr = ts.TosaSerializerAttribute()
1263 attr.ReshapeAttribute(newShape)
1264
1265 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1266 return result_tens
1267
1268 def build_reverse(self, op, a, axis):
1269 result_tens = OutputShaper.unaryOp(self.ser, a)
1270
1271 attr = ts.TosaSerializerAttribute()
1272 attr.AxisAttribute(axis)
1273
1274 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1275 return result_tens
1276
1277 def build_transpose(self, op, a, perms):
1278 result_tens = OutputShaper.transposeOp(self.ser, a, perms)
1279
Kevin Cheng550ccc52021-03-03 11:21:43 -08001280 perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))
Eric Kunzee5e26762020-10-13 16:11:07 -07001281
1282 self.ser.addOperator(op, [a.name, perms_tens.name], [result_tens.name])
1283 return result_tens
1284
1285 def build_slice(self, op, a, begin, size):
1286 result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)
1287
1288 attr = ts.TosaSerializerAttribute()
1289 attr.SliceAttribute(begin, size)
1290
1291 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1292 return result_tens
1293
1294 def build_tile(self, op, a, multiples):
1295 result_tens = OutputShaper.tileOp(self.ser, a, multiples)
1296
1297 attr = ts.TosaSerializerAttribute()
1298 attr.TileAttribute(multiples)
1299
1300 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1301 return result_tens
1302
Kevin Cheng77d0f762020-11-24 10:26:32 -08001303 def build_gather(self, op, values):
Eric Kunzee5e26762020-10-13 16:11:07 -07001304
1305 # Create a new indicies tensor
1306 # here with data that doesn't exceed the dimensions of the values tensor
1307
Kevin Cheng550ccc52021-03-03 11:21:43 -08001308 K = values.shape[1] # K
1309 W = self.randInt(
1310 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
1311 ) # W
1312 indicies_arr = np.int32(
1313 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
1314 ) # (N, W)
1315 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001316
Kevin Cheng77d0f762020-11-24 10:26:32 -08001317 result_tens = OutputShaper.gatherOp(self.ser, values, indicies)
Eric Kunzee5e26762020-10-13 16:11:07 -07001318
Kevin Cheng77d0f762020-11-24 10:26:32 -08001319 self.ser.addOperator(op, [values.name, indicies.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001320
1321 return result_tens
1322
Kevin Cheng77d0f762020-11-24 10:26:32 -08001323 def build_scatter(self, op, values_in, input):
1324
1325 # Create a new indicies tensor
1326 # here with data that doesn't exceed the dimensions of the values_in tensor
1327
Kevin Cheng550ccc52021-03-03 11:21:43 -08001328 K = values_in.shape[1] # K
1329 W = input.shape[1] # W
1330 indicies_arr = np.int32(
1331 self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
1332 ) # (N, W)
1333 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001334
1335 result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)
1336
Kevin Cheng550ccc52021-03-03 11:21:43 -08001337 self.ser.addOperator(
1338 op, [values_in.name, indicies.name, input.name], [result_tens.name]
1339 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001340
1341 return result_tens
1342
Kevin Cheng550ccc52021-03-03 11:21:43 -08001343 def build_resize(
1344 self,
1345 op,
1346 input,
1347 mode,
1348 stride,
1349 offset,
1350 shift,
1351 stride_fp,
1352 offset_fp,
1353 output_dims,
1354 input_dtype,
1355 output_dtype,
1356 ):
1357 result_tens = OutputShaper.resizeOp(
1358 self.ser,
1359 input,
1360 mode,
1361 stride,
1362 offset,
1363 shift,
1364 stride_fp,
1365 offset_fp,
1366 output_dims,
1367 input_dtype,
1368 output_dtype,
1369 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001370
1371 attr = ts.TosaSerializerAttribute()
Kevin Cheng77d0f762020-11-24 10:26:32 -08001372
Kevin Cheng550ccc52021-03-03 11:21:43 -08001373 attr.ResizeAttribute(
1374 output_dims, stride, offset, shift, stride_fp, offset_fp, mode
1375 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001376
1377 self.ser.addOperator(op, [input.name], [result_tens.name], attr)
1378 return result_tens
1379
1380 def build_identityn(self, op, val, val2):
1381
Kevin Cheng550ccc52021-03-03 11:21:43 -08001382 result_tens = OutputShaper.unaryOp(self.ser, val)
Eric Kunzee5e26762020-10-13 16:11:07 -07001383 result_tens2 = OutputShaper.unaryOp(self.ser, val2)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001384 self.ser.addOperator(
1385 op, [val.name, val2.name], [result_tens.name, result_tens2.name]
1386 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001387 return result_tens
1388
1389 def build_placeholder(self, op, val):
1390 # Add an identity op to avoid warning in the reference model
1391 return self.build_unary(Op.IDENTITY, val)
1392
1393 # Type Conversion
1394 def build_cast(self, op, val, out_dtype):
1395 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
1396 self.ser.addOperator(op, [val.name], [result_tens.name])
1397 return result_tens
1398
1399 def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
1400 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
1401
1402 if per_channel:
1403 nc = val.shape[-1]
1404 else:
1405 nc = 1
1406
1407 in_type_width = self.typeWidth(val.dtype)
1408 out_type_width = self.typeWidth(out_dtype)
1409
Kevin Cheng3a478572021-01-22 17:21:02 -08001410 if val.dtype == DType.INT8:
Kevin Cheng989cb052021-04-28 16:29:44 -07001411 input_zp = self.randInt(-128, 127)
Eric Kunzee5e26762020-10-13 16:11:07 -07001412 in_type_width = in_type_width + 1
1413 else:
1414 input_zp = 0
1415
Kevin Cheng3a478572021-01-22 17:21:02 -08001416 if out_dtype == DType.INT8:
Kevin Cheng989cb052021-04-28 16:29:44 -07001417 output_zp = self.randInt(-128, 127)
Eric Kunzee5e26762020-10-13 16:11:07 -07001418 out_type_width = out_type_width + 1
1419 else:
1420 output_zp = 0
1421
1422 # Calculate scale based on:
1423 # scale = a *(2^output_width)/(2^input_width))
1424
1425 a = np.float32(self.rng.random(size=[nc]))
1426 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
1427
1428 if scale32:
1429 pass
1430 # Cap the scaling at 2^15 - 1 for scale16
1431 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
1432 else:
1433 # Cap the scaling at 2^15 - 1 for scale16
1434 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
1435
Kevin Cheng550ccc52021-03-03 11:21:43 -08001436 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001437
1438 multiplier_arr = np.int32(np.zeros(shape=[nc]))
1439 shift_arr = np.int32(np.zeros(shape=[nc]))
1440
1441 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001442 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
1443 scale_arr[i], scale32
1444 )
Kevin Chengaee1fac2020-11-11 13:54:06 -08001445 if shift_arr[i] < 2 or shift_arr[i] > 62:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001446 self.ser.setExpectedFailure(True, "OpRescale: invalid shift value")
Eric Kunzee5e26762020-10-13 16:11:07 -07001447
Kevin Cheng550ccc52021-03-03 11:21:43 -08001448 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07001449
1450 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08001451 attr.RescaleAttribute(
1452 input_zp,
1453 output_zp,
1454 multiplier_arr,
1455 shift_arr,
1456 scale32,
1457 double_round,
1458 per_channel,
1459 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001460
1461 self.ser.addOperator(op, [val.name], [result_tens.name], attr)
1462 return result_tens
1463
    def build_cond_if_const(self, op, then_tens, else_tens, cond):
        """COND_IF where both branches are constant tensors.

        The supplied then/else tensors are only used for their shape; the
        actual branch bodies are const nodes built inside THEN/ELSE basic
        blocks. Returns the INT32 result tensor of the cond_if.
        """
        # For cond_if with constants, we're supplied with then/else tensors that we ignore
        # (except for the generated shape) and the condition. Build Then/Else blocks
        # and fill them with const nodes for the body.

        # Condition tensor (rank-0 BOOL const)
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        # Make then/else tensors
        out_shape = then_tens.shape
        then_arr = np.int32(self.rng.integers(0, 255, size=out_shape))
        else_arr = np.int32(self.rng.integers(0, 255, size=out_shape))

        # And the result tensor based on any of the outputs
        result_tens = self.ser.addOutput(out_shape, DType.INT32)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks.
        # NOTE: the operator must be added to the current block *before*
        # startBasicBlock switches the serializer to the branch blocks.
        self.ser.addOperator(op, [cond_tens.name], [result_tens.name], attr)

        self.ser.startBasicBlock(then_block)
        # Build the actual then/else tensors inside their blocks
        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
        self.ser.addOutputTensor(then_tens)

        self.ser.startBasicBlock(else_block)
        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
        self.ser.addOutputTensor(else_tens)

        return result_tens
1499
    def build_cond_if_binary(self, op, a, b, cond):
        """COND_IF with a binary op in each branch: THEN computes a + b, ELSE computes a - b.

        Returns the result tensor of the cond_if.
        """
        # For cond_if with a binary op in the then/else blocks, take a and b and
        # alternately add or subtract them based on the condition

        # Condition tensor (rank-0 BOOL const)
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        result_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.currBasicBlock.addOutput(result_tens.name)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks.
        # NOTE: the operator is added to the current block before the
        # serializer is switched to the branch blocks below.
        self.ser.addOperator(
            op, [cond_tens.name, a.name, b.name], [result_tens.name], attr
        )

        # THEN: out = a + b
        self.ser.startBasicBlock(then_block)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        then_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])

        # ELSE: out = a - b
        self.ser.startBasicBlock(else_block)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        else_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])

        return result_tens
1534
    def build_while_loop(self, op, a, iter_val):
        """WHILE_LOOP that accumulates `a` into an accumulator `iter_val` times.

        Loop state is (iter, a, acc): COND tests iter > 0, BODY does
        acc += a and iter -= 1. Returns the accumulator output tensor.
        """
        # Loop counter, initialized to iter_val
        iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

        cond_block = "COND_BLOCK"
        body_block = "BODY_BLOCK"

        attr = ts.TosaSerializerAttribute()
        attr.WhileLoopAttribute(cond_block, body_block)

        # Accumulator tensor
        # acc = self.ser.addOutput(a.shape, a.dtype)
        acc_init_val = np.int32(np.zeros(a.shape))
        acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

        # Intermediate/output tensors for everything going through the loop
        iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        a_out = self.ser.addIntermediate(a.shape, a.dtype)
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

        # While_loop operator: inputs and outputs are the full loop state
        self.ser.addOperator(
            op,
            [iter.name, a.name, acc.name],
            [iter_out.name, a_out.name, acc_out.name],
            attr,
        )

        # COND block (input: iter, output: cond_tens )
        self.ser.startBasicBlock(cond_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
        cond_tens = self.ser.addOutput([], DType.BOOL)
        self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])

        # BODY block (input: a, acc, iter, output: a, acc, iter)
        # Note that local intermediate tensors need to be declared here for the outputs
        self.ser.startBasicBlock(body_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
        iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
        self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
        self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
        self.ser.addOutputTensor(iter_body_out)
        self.ser.addOutputTensor(a)
        self.ser.addOutputTensor(acc_body_out)

        return acc_out
1587
    def genOpTestList(
        self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None
    ):
        """Generate the list of test descriptors for one operator.

        Iterates rank x dtype x shape (after applying the optional filters)
        and expands each combination through the op's argument generator.
        Returns a list of (opName, testNameStr, dtype, shapeList, args)
        tuples. The RNG is reseeded on entry so that, for a fixed seed, the
        generated shapes/arguments are reproducible; the draw order inside
        the loops below must therefore not be changed.
        """

        try:
            op = self.TOSA_OP_LIST[opName]
        except KeyError as e:
            raise Exception("Cannot find op with name {}".format(opName))

        # Initialize a new random number generator
        self.rng = np.random.default_rng(self.random_seed)

        build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]

        # Generate the lists of arguments
        rmin, rmax = op["rank"]

        # Test list consists of a tuple of:
        # (opName, testNameStr, dtype, shapeList, argumentsList)
        testList = []

        # Normalize an empty filter to "one random shape per rank"
        if not shapeFilter:
            shapeFilter = [None]

        for r in range(rmin, rmax + 1):

            # Filter out the rank?
            if rankFilter is not None and r not in rankFilter:
                continue

            for t in op["types"]:

                # Filter tests based on dtype?
                if dtypeFilter is not None:
                    if t not in dtypeFilter:
                        continue

                # Create the placeholder and const tensors
                for shape in shapeFilter:
                    # A None shape chooses a random shape of a given rank

                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue

                    # Force (or un-force) the shape, then run the op's
                    # tensor-shape generator to get the operand shapes
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
                    argList = []
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if argStr:
                            testStr = "{}_{}_{}_{}".format(
                                opName, shapeStr, typeStr, argStr
                            )
                        else:
                            testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)

                        testList.append((opName, testStr, t, shapeList, args))

        return testList
1657
Kevin Cheng989cb052021-04-28 16:29:44 -07001658 def serializeTest(self, opName, testStr, dtype_or_dtypeList, shapeList, testArgs):
Eric Kunzee5e26762020-10-13 16:11:07 -07001659 try:
1660 op = self.TOSA_OP_LIST[opName]
1661 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001662 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07001663
1664 # Create a serializer
1665 self.createSerializer(opName, testStr)
1666
Kevin Cheng550ccc52021-03-03 11:21:43 -08001667 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
1668 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07001669 num_operands = pCount + cCount
1670
1671 if isinstance(dtype_or_dtypeList, list):
1672 dtypeList = dtype_or_dtypeList
1673 else:
1674 dtypeList = [dtype_or_dtypeList] * (num_operands)
1675
1676 assert (
1677 len(shapeList) == num_operands
1678 ), "shapeList length {} must match number of operands {}".format(
1679 len(shapeList), num_operands
1680 )
1681 assert (
1682 len(dtypeList) == num_operands
1683 ), "dtypeList length {} must match number of operands {}".format(
1684 len(dtypeList), num_operands
1685 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001686
1687 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001688 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001689 except KeyError:
1690 qgen = None
1691
1692 # Build the random tensor operands and the test
1693 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08001694
1695 # If test is ArithmeticRightShift, force value of operand[1] to be within [0, num_bits]
Kevin Cheng550ccc52021-03-03 11:21:43 -08001696 if op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
1697 assert (
1698 pCount == 2 and cCount == 0
1699 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
Kevin Chengaee1fac2020-11-11 13:54:06 -08001700
1701 placeholders = []
1702 for idx, shape in enumerate(shapeList[:]):
1703 if idx == 1:
Kevin Cheng989cb052021-04-28 16:29:44 -07001704 if dtypeList[idx] == DType.INT8:
Kevin Chengaee1fac2020-11-11 13:54:06 -08001705 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07001706 elif dtypeList[idx] == DType.INT16:
Kevin Chengaee1fac2020-11-11 13:54:06 -08001707 arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07001708 elif dtypeList[idx] == DType.INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -08001709 arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
1710 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001711 raise Exception("OpArithmeticRightShift: invalid input dtype")
Kevin Chengaee1fac2020-11-11 13:54:06 -08001712 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07001713 arr = self.getRandTensor(shapeList[0], dtypeList[idx])
1714 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
Kevin Chengaee1fac2020-11-11 13:54:06 -08001715
1716 tens.extend(placeholders)
1717 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07001718 tens.extend(
1719 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
1720 )
1721 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001722
1723 if qgen is not None:
Kevin Cheng989cb052021-04-28 16:29:44 -07001724 qinfo = qgen(self, op, dtypeList[0])
Eric Kunzee5e26762020-10-13 16:11:07 -07001725 else:
1726 qinfo = None
1727
1728 try:
1729 if qinfo is not None:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001730 resultName = build_fcn(self, op["op"], *tens, *testArgs, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07001731 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001732 resultName = build_fcn(self, op["op"], *tens, *testArgs)
Eric Kunzee5e26762020-10-13 16:11:07 -07001733 except TypeError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001734 print(
1735 "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
1736 build_fcn, tens, testArgs
1737 )
1738 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001739 raise e
1740
1741 # Save the serialized test
Kevin Cheng550ccc52021-03-03 11:21:43 -08001742 self.serialize("test")
Eric Kunzee5e26762020-10-13 16:11:07 -07001743
1744 def createDynamicOpLists(self):
1745
1746 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng550ccc52021-03-03 11:21:43 -08001747 KERNELS = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07001748
1749 for k in KERNELS:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001750 testName = "conv2d_{}x{}".format(k[0], k[1])
1751 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
1752 self.TOSA_OP_LIST[testName]["filter"] = k
1753 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001754
Kevin Cheng550ccc52021-03-03 11:21:43 -08001755 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
1756 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
1757 "depthwise_conv2d_TEMPLATE"
1758 ].copy()
1759 self.TOSA_OP_LIST[testName]["filter"] = k
1760 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001761
Kevin Cheng550ccc52021-03-03 11:21:43 -08001762 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
1763 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
1764 "transpose_conv2d_TEMPLATE"
1765 ].copy()
1766 self.TOSA_OP_LIST[testName]["filter"] = k
1767 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001768
1769 # Delete any templates after having created any dynamic ops
1770 # This is a two-pass operation because it's bad practice to delete
1771 # keys from dictionaries while iterating
1772 keyList = []
1773 for k in self.TOSA_OP_LIST:
1774 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001775 if self.TOSA_OP_LIST[k]["template"] == True:
Eric Kunzee5e26762020-10-13 16:11:07 -07001776 keyList.append(k)
1777 continue
1778 except KeyError:
1779 pass
1780
1781 for k in keyList:
1782 del self.TOSA_OP_LIST[k]
1783
1784 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001785 """Fill in default fields for ops if they aren't already specified.
1786 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07001787 for op in self.TOSA_OP_LIST:
1788
1789 # Required fields
1790 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001791 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001792 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001793 raise Exception(
1794 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
1795 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001796
1797 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001798 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001799 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001800 raise Exception(
1801 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
1802 op
1803 )
1804 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001805
1806 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001807 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001808 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001809 raise Exception(
1810 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
1811 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001812
1813 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001814 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001815 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001816 raise Exception(
1817 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
1818 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001819
1820 # Put in default rank range, if missing
1821 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001822 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001823 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001824 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07001825
    # Tensor operator list
    #  'op': op name
    #  'operands': tuple of (placeholder, const) operands
    #  'rank': optional, restricts rank to tuple inclusive of (min, max),
    #   if not specified, defaults to (1, 4)
    #  'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
    #  'types': array of datatypes to be tested
    TYPE_FP = [DType.FLOAT]

    TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
    TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

    TYPE_BOOL = [DType.BOOL]
    # Float plus 32-bit int: types accepted by most elementwise binary ops
    TYPE_FI32 = [DType.FLOAT, DType.INT32]
    # Float, ints and bool: types accepted by data-layout / data-movement ops
    TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
    TYPE_FI16 = [DType.FLOAT, DType.INT16]

    # 8/16-bit "narrow" ints plus float (pooling, argmax, matmul, clamp)
    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

    # Conv-style ops use per-operand dtype lists.  Each inner list is
    # presumably [input, weight, accumulator] — the widened third type matches
    # the dtype mapping in OutputShaper.conv2dOp.  A bare DType means the same
    # type is used for every operand.
    TYPE_CONV2D = [
        [DType.INT8, DType.INT8, DType.INT32],
        [DType.INT16, DType.INT8, DType.INT48],
        DType.FLOAT,
    ]

    # Rank range applied by initOpListDefaults when an entry omits 'rank'
    DEFAULT_RANK_RANGE = (1, 4)
1852
1853 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08001854 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08001855 "argmax": {
1856 "op": Op.ARGMAX,
1857 "operands": (1, 0),
1858 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
1859 "types": TYPE_NARROW_INT_FP,
1860 },
Jared Smolens573ecd42021-03-04 15:24:10 -08001861
1862 "avg_pool2d": {
1863 "op": Op.AVG_POOL2D,
1864 "operands": (1, 0),
1865 "rank": (4, 4),
1866 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
1867 "qgen": TosaQuantGen.qgUnary,
1868 "types": TYPE_NARROW_INT_FP,
1869 },
1870
Eric Kunzee5e26762020-10-13 16:11:07 -07001871 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08001872 "conv2d_TEMPLATE": {
1873 "op": Op.CONV2D,
1874 "operands": (1, 2),
1875 "rank": (4, 4),
1876 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D),
1877 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07001878 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001879 "template": True,
1880 },
Jared Smolens573ecd42021-03-04 15:24:10 -08001881
1882 # Conv3d TBD
1883
Eric Kunzee5e26762020-10-13 16:11:07 -07001884 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08001885 "depthwise_conv2d_TEMPLATE": {
1886 "op": Op.DEPTHWISE_CONV2D,
1887 "operands": (1, 2),
1888 "filter": [1, 1],
1889 "rank": (4, 4),
1890 "build_fcn": (
1891 build_depthwise_conv2d,
1892 TosaTensorGen.tgDepthwiseConv2D,
1893 TosaArgGen.agConv2D,
1894 ),
1895 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07001896 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001897 "template": True,
1898 },
Jared Smolens573ecd42021-03-04 15:24:10 -08001899
1900 "fully_connected": {
1901 "op": Op.FULLY_CONNECTED,
1902 "operands": (1, 2),
1903 "rank": (2, 2),
1904 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
1905 "qgen": TosaQuantGen.qgConv,
1906 "types": TYPE_CONV2D,
1907 },
1908
1909 "matmul": {
1910 "op": Op.MATMUL,
1911 "operands": (2, 0),
1912 "rank": (2, 2),
1913 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
1914 "qgen": TosaQuantGen.qgMatmul,
1915 "types": TYPE_NARROW_INT_FP,
1916 },
1917
1918 "max_pool2d": {
1919 "op": Op.MAX_POOL2D,
1920 "operands": (1, 0),
1921 "rank": (4, 4),
1922 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
1923 "types": TYPE_NARROW_INT_FP,
1924 },
1925
Eric Kunzee5e26762020-10-13 16:11:07 -07001926 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08001927 "transpose_conv2d_TEMPLATE": {
1928 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07001929 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08001930 "rank": (4, 4),
1931 "build_fcn": (
1932 build_transpose_conv2d,
1933 TosaTensorGen.tgTransposeConv2D,
1934 TosaArgGen.agTransposeConv2D,
1935 ),
1936 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07001937 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001938 "template": True,
1939 },
Jared Smolens573ecd42021-03-04 15:24:10 -08001940
Eric Kunzee5e26762020-10-13 16:11:07 -07001941 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08001942 "clamp": {
1943 "op": Op.CLAMP,
1944 "operands": (1, 0),
1945 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
1946 "types": TYPE_NARROW_INT_FP,
1947 },
Jared Smolens573ecd42021-03-04 15:24:10 -08001948
Kevin Cheng550ccc52021-03-03 11:21:43 -08001949 "relun": {
1950 "op": Op.RELUN,
1951 "operands": (1, 0),
1952 "build_fcn": (build_relun, TosaTensorGen.tgBasic, None),
1953 "types": TYPE_FI32,
1954 },
Jared Smolens573ecd42021-03-04 15:24:10 -08001955
Kevin Cheng550ccc52021-03-03 11:21:43 -08001956 "sigmoid": {
1957 "op": Op.SIGMOID,
1958 "operands": (1, 0),
1959 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
1960 "types": TYPE_FP,
1961 },
Jared Smolens573ecd42021-03-04 15:24:10 -08001962
Kevin Cheng550ccc52021-03-03 11:21:43 -08001963 "tanh": {
1964 "op": Op.TANH,
1965 "operands": (1, 0),
1966 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
1967 "types": TYPE_FP,
1968 },
Jared Smolens573ecd42021-03-04 15:24:10 -08001969
1970 # Elementwise Binary Operators
1971 "add": {
1972 "op": Op.ADD,
1973 "operands": (2, 0),
1974 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
1975 "types": TYPE_FI32,
1976 },
1977
1978 "arithmetic_right_shift": {
1979 "op": Op.ARITHMETIC_RIGHT_SHIFT,
1980 "operands": (2, 0),
1981 "build_fcn": (
1982 build_arithmetic_right_shift,
1983 TosaTensorGen.tgBroadcastFuzz,
1984 TosaArgGen.agArithmeticRightShift,
1985 ),
1986 "types": TYPE_INT,
1987 },
1988
1989 "bitwise_and": {
1990 "op": Op.BITWISE_AND,
1991 "operands": (2, 0),
1992 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
1993 "types": TYPE_INT,
1994 },
1995
1996 "bitwise_or": {
1997 "op": Op.BITWISE_OR,
1998 "operands": (2, 0),
1999 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2000 "types": TYPE_INT,
2001 },
2002
2003 "bitwise_xor": {
2004 "op": Op.BITWISE_XOR,
2005 "operands": (2, 0),
2006 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2007 "types": TYPE_INT,
2008 },
2009
2010 "logical_and": {
2011 "op": Op.LOGICAL_AND,
2012 "operands": (2, 0),
2013 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2014 "types": TYPE_BOOL,
2015 },
2016
2017 "logical_left_shift": {
2018 "op": Op.LOGICAL_LEFT_SHIFT,
2019 "operands": (2, 0),
2020 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2021 "types": TYPE_INT,
2022 },
2023
2024 "logical_right_shift": {
2025 "op": Op.LOGICAL_RIGHT_SHIFT,
2026 "operands": (2, 0),
2027 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2028 "types": TYPE_INT,
2029 },
2030
2031 "logical_or": {
2032 "op": Op.LOGICAL_OR,
2033 "operands": (2, 0),
2034 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2035 "types": TYPE_BOOL,
2036 },
2037
2038 "logical_xor": {
2039 "op": Op.LOGICAL_XOR,
2040 "operands": (2, 0),
2041 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2042 "types": TYPE_BOOL,
2043 },
2044
2045 "maximum": {
2046 "op": Op.MAXIMUM,
2047 "operands": (2, 0),
2048 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2049 "types": TYPE_FI32,
2050 },
2051
2052 "minimum": {
2053 "op": Op.MINIMUM,
2054 "operands": (2, 0),
2055 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2056 "types": TYPE_FI32,
2057 },
2058
2059 "mul": {
2060 "op": Op.MUL,
2061 "operands": (2, 0),
2062 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
2063 "types": TYPE_INT_FP,
2064 },
2065
2066 "pow": {
2067 "op": Op.POW,
2068 "operands": (2, 0),
2069 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
2070 "types": TYPE_FP,
2071 },
2072
2073 "sub": {
2074 "op": Op.SUB,
2075 "operands": (2, 0),
2076 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2077 "types": TYPE_FI32,
2078 },
2079
2080 "table": {
2081 "op": Op.TABLE,
2082 # Use the automatic generation functions to create the input array
2083 # but create the table tensor in the build function, as it may be
2084 # a different type from the input
2085 "operands": (1, 0),
2086 "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
2087 "types": [DType.INT16],
2088 },
2089
2090 # Elementwise Unary operators
2091 "abs": {
2092 "op": Op.ABS,
2093 "operands": (1, 0),
2094 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2095 "types": TYPE_FI32,
2096 },
2097
2098 "bitwise_not": {
2099 "op": Op.BITWISE_NOT,
2100 "operands": (1, 0),
2101 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2102 "types": TYPE_INT,
2103 },
2104
2105 "ceil": {
2106 "op": Op.CEIL,
2107 "operands": (1, 0),
2108 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2109 "types": TYPE_FP,
2110 },
2111
2112 "clz": {
2113 "op": Op.CLZ,
2114 "operands": (1, 0),
2115 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2116 "types": [DType.INT32],
2117 },
2118
2119 "exp": {
2120 "op": Op.EXP,
2121 "operands": (1, 0),
2122 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2123 "types": TYPE_FP,
2124 },
2125
2126 "floor": {
2127 "op": Op.FLOOR,
2128 "operands": (1, 0),
2129 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2130 "types": TYPE_FP,
2131 },
2132
2133 "log": {
2134 "op": Op.LOG,
2135 "operands": (1, 0),
2136 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2137 "types": TYPE_FP,
2138 },
2139
2140 "logical_not": {
2141 "op": Op.LOGICAL_NOT,
2142 "operands": (1, 0),
2143 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2144 "types": TYPE_BOOL,
2145 },
2146
2147 "negate": {
2148 "op": Op.NEGATE,
2149 "operands": (1, 0),
2150 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2151 "qgen": TosaQuantGen.qgUnary,
2152 "types": TYPE_INT_FP,
2153 },
2154
2155 "reciprocal": {
2156 "op": Op.RECIPROCAL,
2157 "operands": (1, 0),
2158 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2159 "types": TYPE_FP,
2160 },
2161
2162 "rsqrt": {
2163 "op": Op.RSQRT,
2164 "operands": (1, 0),
2165 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2166 "types": TYPE_FP,
2167 },
2168
2169 # Elementwise Ternary operators
2170 "select": {
2171 "op": Op.SELECT,
2172 "operands": (3, 0),
2173 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
2174 "types": TYPE_FIB,
2175 },
2176
2177 # Comparison operators
2178 "equal": {
2179 "op": Op.EQUAL,
2180 "operands": (2, 0),
2181 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2182 "types": TYPE_FI32,
2183 },
2184
2185 "greater_equal": {
2186 "op": Op.GREATER_EQUAL,
2187 "operands": (2, 0),
2188 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2189 "types": TYPE_FI32,
2190 },
2191
2192 "greater": {
2193 "op": Op.GREATER,
2194 "operands": (2, 0),
2195 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2196 "types": TYPE_FI32,
2197 },
2198
2199 # Reduction operators
2200 "reduce_all": {
2201 "op": Op.REDUCE_ALL,
2202 "operands": (1, 0),
2203 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2204 "types": TYPE_BOOL,
2205 },
2206
2207 "reduce_any": {
2208 "op": Op.REDUCE_ANY,
2209 "operands": (1, 0),
2210 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2211 "types": TYPE_BOOL,
2212 },
2213
2214 "reduce_max": {
2215 "op": Op.REDUCE_MAX,
2216 "operands": (1, 0),
2217 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2218 "types": TYPE_INT_FP,
2219 },
2220
2221 "reduce_min": {
2222 "op": Op.REDUCE_MAX,
2223 "operands": (1, 0),
2224 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2225 "types": TYPE_INT_FP,
2226 },
2227
2228 "reduce_product": {
2229 "op": Op.REDUCE_PRODUCT,
2230 "operands": (1, 0),
2231 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2232 "types": TYPE_FP,
2233 },
2234
2235 "reduce_sum": {
2236 "op": Op.REDUCE_SUM,
2237 "operands": (1, 0),
2238 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2239 "types": TYPE_FI32,
2240 },
2241
Eric Kunzee5e26762020-10-13 16:11:07 -07002242 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002243 "concat": {
2244 "op": Op.CONCAT,
2245 "operands": (2, 0),
2246 "build_fcn": (build_concat, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2247 "types": TYPE_FIB,
2248 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002249
Kevin Cheng550ccc52021-03-03 11:21:43 -08002250 "pad": {
2251 "op": Op.PAD,
2252 "operands": (1, 0),
2253 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
2254 "qgen": TosaQuantGen.qgPad,
2255 "types": TYPE_FIB,
2256 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002257
Kevin Cheng550ccc52021-03-03 11:21:43 -08002258 "reshape": {
2259 "op": Op.RESHAPE,
2260 "operands": (1, 0),
2261 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
2262 "types": TYPE_FIB,
2263 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002264
Kevin Cheng550ccc52021-03-03 11:21:43 -08002265 "reverse": {
2266 "op": Op.REVERSE,
2267 "operands": (1, 0),
2268 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2269 "types": TYPE_FIB,
2270 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002271
Kevin Cheng550ccc52021-03-03 11:21:43 -08002272 "slice": {
2273 "op": Op.SLICE,
2274 "operands": (1, 0),
2275 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
2276 "types": TYPE_FIB,
2277 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002278
Kevin Cheng550ccc52021-03-03 11:21:43 -08002279 "tile": {
2280 "op": Op.TILE,
2281 "operands": (1, 0),
2282 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
2283 "types": TYPE_FIB,
2284 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002285
Kevin Cheng550ccc52021-03-03 11:21:43 -08002286 "transpose": {
2287 "op": Op.TRANSPOSE,
2288 "operands": (1, 0),
2289 "rank": (2, 4), # Do not allow tranpose on rank=1
2290 "build_fcn": (
2291 build_transpose,
2292 TosaTensorGen.tgBasic,
2293 TosaArgGen.agTranspose,
2294 ),
2295 "types": TYPE_FIB,
2296 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002297
2298 # Data nodes
2299 "const": {
2300 "op": Op.CONST,
2301 "operands": (1, 0),
2302 "build_fcn": (build_placeholder, TosaTensorGen.tgBasic, None),
2303 "types": TYPE_FIB,
2304 },
2305
2306 "identity": {
2307 "op": Op.IDENTITY,
2308 "operands": (1, 0),
2309 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2310 "types": TYPE_FIB,
2311 },
2312
2313 "identityn": {
2314 "op": Op.IDENTITYN,
2315 "operands": (2, 0),
2316 "build_fcn": (build_identityn, TosaTensorGen.tgBasic, None),
2317 "types": TYPE_FIB,
2318 },
2319
2320 "placeholder": {
2321 "op": Op.PLACEHOLDER,
2322 "operands": (1, 0),
2323 "build_fcn": (build_placeholder, TosaTensorGen.tgBasic, None),
2324 "types": TYPE_FIB,
2325 },
2326
Eric Kunzee5e26762020-10-13 16:11:07 -07002327 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08002328 "gather": {
2329 "op": Op.GATHER,
2330 # Only specify 'values' tensor here. 'indices' is generated in op building stage
2331 "operands": (1, 0),
2332 "rank": (3, 3),
2333 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
2334 "types": TYPE_INT_FP,
2335 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002336
Kevin Cheng550ccc52021-03-03 11:21:43 -08002337 "scatter": {
2338 "op": Op.SCATTER,
2339 # Only specify 'values_in' tensor here.
2340 #'indices' and 'input' are generated in op building stage
2341 "operands": (2, 0),
2342 "rank": (3, 3),
2343 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
2344 "types": TYPE_INT_FP,
2345 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002346
Eric Kunzee5e26762020-10-13 16:11:07 -07002347 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08002348 "resize": {
2349 "op": Op.RESIZE,
2350 "operands": (1, 0),
2351 "rank": (4, 4),
2352 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
2353 "types": [DType.INT8, DType.INT16, DType.FLOAT],
2354 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002355
Eric Kunzee5e26762020-10-13 16:11:07 -07002356 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08002357 "cast": {
2358 "op": Op.CAST,
2359 "operands": (1, 0),
2360 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
2361 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
2362 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002363
Kevin Cheng550ccc52021-03-03 11:21:43 -08002364 "rescale": {
2365 "op": Op.RESCALE,
2366 "operands": (1, 0),
2367 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
2368 "types": [DType.INT8, DType.INT16, DType.INT32, DType.INT48],
2369 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002370
Eric Kunzee5e26762020-10-13 16:11:07 -07002371 # Custom
2372 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08002373
2374
2375 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07002376 # Two varients of cond_if, one that generates one of two constant tensors (no
2377 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
2378 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002379 "cond_if_const": {
2380 "op": Op.COND_IF,
2381 "operands": (0, 2),
2382 "build_fcn": (
2383 build_cond_if_const,
2384 TosaTensorGen.tgBasic,
2385 TosaArgGen.agCondIf,
2386 ),
2387 "types": [DType.BOOL],
2388 },
2389 "cond_if_binary": {
2390 "op": Op.COND_IF,
2391 "operands": (2, 0),
2392 "build_fcn": (
2393 build_cond_if_binary,
2394 TosaTensorGen.tgBasic,
2395 TosaArgGen.agCondIf,
2396 ),
2397 "types": TYPE_FI32,
2398 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002399 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002400 "while_loop": {
2401 "op": Op.WHILE_LOOP,
2402 "operands": (0, 1),
2403 "build_fcn": (
2404 build_while_loop,
2405 TosaTensorGen.tgBasic,
2406 TosaArgGen.agWhileLoop,
2407 ),
2408 "types": [DType.INT32],
2409 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002410 }
2411
Kevin Cheng550ccc52021-03-03 11:21:43 -08002412
Eric Kunzee5e26762020-10-13 16:11:07 -07002413class OutputShaper:
2414 # Methods in this class compute the expected output shape and datatype
2415 # for common classes of operations
    def __init__(self):
        # Stateless: every shape-computation method below is a @staticmethod.
        pass
2418
2419 # These methods return arguments that can be used for
2420 # creating a new output tensor
2421 @staticmethod
2422 def binaryBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002423 assert len(a.shape) == len(b.shape)
2424 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002425
2426 shape = []
2427 for i in range(len(a.shape)):
2428 if a.shape[i] == 1:
2429 shape.append(b.shape[i])
2430 else:
2431 shape.append(a.shape[i])
2432
Kevin Cheng550ccc52021-03-03 11:21:43 -08002433 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002434
2435 @staticmethod
2436 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002437 assert len(a.shape) == len(b.shape)
2438 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002439
2440 shape = []
2441 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002442 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07002443 shape.append(a.shape[i])
2444
Kevin Cheng550ccc52021-03-03 11:21:43 -08002445 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002446
2447 @staticmethod
2448 def unaryOp(ser, a):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002449 return ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002450
2451 @staticmethod
2452 def selectOp(ser, cond, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002453 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
2454 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002455
2456 shape = []
2457 for i in range(len(a.shape)):
2458 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
2459
Kevin Cheng550ccc52021-03-03 11:21:43 -08002460 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002461
2462 @staticmethod
2463 def binaryComparisonOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002464 assert len(a.shape) == len(b.shape)
2465 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002466
2467 # Do broadcast
2468 shape = []
2469 for i in range(len(a.shape)):
2470 if a.shape[i] == 1:
2471 shape.append(b.shape[i])
2472 else:
2473 shape.append(a.shape[i])
2474
2475 # Force the output type to bool
Kevin Cheng550ccc52021-03-03 11:21:43 -08002476 return ser.addOutput(shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07002477
2478 @staticmethod
2479 def reduceOp(ser, a, axis):
2480
2481 shape = a.shape.copy()
2482
2483 shape[axis] = 1
2484
Kevin Cheng550ccc52021-03-03 11:21:43 -08002485 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002486
2487 @staticmethod
2488 def argmaxOp(ser, a, axis):
2489 shape = a.shape.copy()
2490 del shape[axis]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002491 return ser.addOutput(shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07002492
    @staticmethod
    def conv2dOp(ser, ifm, filter, strides, padding, dilations):
        """Compute the CONV2D output shape/dtype and register it with `ser`.

        Raises Exception for an input dtype with no accumulator mapping.
        """

        # IFM: NHWC
        # Filter: OHWI
        # OFM: NHWC

        if len(padding) == 2:
            # Expand padding to 4 parameters in the case of transpose_conv2d
            # From H,W to T,B,L,R
            padding = [padding[0], padding[0], padding[1], padding[1]]

        # Output height: padded input height minus the dilated kernel extent
        # (filter + (filter - 1) * (dilation - 1)), floor-divided by the
        # vertical stride, plus one for the initial window position.
        h = (
            ifm.shape[1]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        # Output width: same computation along the horizontal axis.
        w = (
            ifm.shape[2]
            - filter.shape[2]
            - (filter.shape[2] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        if h <= 0 or w <= 0:
            # Invalid test parameters?  Clamp to zero and mark the test as an
            # expected failure instead of emitting a negative dimension.
            h = 0
            w = 0
            ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")

        # Output channels come from the filter's O dimension.
        ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

        # Accumulator/output dtype is widened from the input dtype.
        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002539
    @staticmethod
    def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
        """Compute the DEPTHWISE_CONV2D output shape/dtype and register it.

        Raises Exception for an input dtype with no accumulator mapping.
        """
        # IFM: NHWC
        # Filter: HWCM
        # OFM: NHW C*M

        # Output height: padded input height minus the dilated kernel extent,
        # floor-divided by the stride, plus one.  Note the filter's H/W live
        # in dims 0/1 here (HWCM), unlike conv2dOp's OHWI layout.
        h = (
            ifm.shape[1]
            - filter.shape[0]
            - (filter.shape[0] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        w = (
            ifm.shape[2]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        if h <= 0 or w <= 0:
            # Invalid test parameters?  Clamp to zero and mark the test as an
            # expected failure instead of emitting a negative dimension.
            h = 0
            w = 0
            ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")

        # Output channels = input channels (C) * channel multiplier (M).
        ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

        # Accumulator/output dtype is widened from the input dtype.
        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002579
2580 @staticmethod
2581 def pool2dOp(ser, ifm, kernel, stride, pad):
2582 # input: NHWC
2583 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
2584 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
2585
2586 if h <= 0 or w <= 0:
2587 # Invalid test parameters?
2588 h = 0
2589 w = 0
Kevin Cheng550ccc52021-03-03 11:21:43 -08002590 ser.setExpectedFailure(True, "Invalid combination of pooling parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002591
2592 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002593 return ser.addOutput(ofm_shape, ifm.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002594
2595 @staticmethod
2596 def fullyConnectedOp(ser, input, filter):
2597 # input: N, IC
2598 # filter: OC, IC
2599 # output: N, OC
2600
2601 output_shape = [input.shape[0], filter.shape[0]]
2602
Kevin Cheng3a478572021-01-22 17:21:02 -08002603 if input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002604 out_dtype = DType.INT32
2605 elif input.dtype == DType.INT16:
2606 out_dtype = DType.INT48
2607 elif input.dtype == DType.FLOAT:
2608 out_dtype = DType.FLOAT
2609 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002610 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002611
Kevin Cheng550ccc52021-03-03 11:21:43 -08002612 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002613
2614 @staticmethod
2615 def matmulOp(ser, a, b):
2616 # a: M, K
2617 # b: K, N
2618 # out: M, N
2619
2620 output_shape = [a.shape[0], b.shape[1]]
2621
Kevin Cheng3a478572021-01-22 17:21:02 -08002622 if a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002623 out_dtype = DType.INT32
2624 elif a.dtype == DType.INT16:
2625 out_dtype = DType.INT48
2626 elif a.dtype == DType.FLOAT:
2627 out_dtype = DType.FLOAT
2628 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002629 raise Exception("UNsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002630
Kevin Cheng550ccc52021-03-03 11:21:43 -08002631 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002632
2633 @staticmethod
2634 def concatOp(ser, a, b, axis):
2635
2636 output_shape = a.shape.copy()
2637 output_shape[axis] = a.shape[axis] + b.shape[axis]
2638
Kevin Cheng550ccc52021-03-03 11:21:43 -08002639 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002640
2641 @staticmethod
2642 def padOp(ser, a, padding):
2643
2644 output_shape = a.shape.copy()
2645
2646 for i in range(len(output_shape)):
2647 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
2648
Kevin Cheng550ccc52021-03-03 11:21:43 -08002649 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002650
2651 @staticmethod
2652 def reshapeOp(ser, a, shape):
2653 output_shape = shape.copy()
2654
2655 totalElements = 1
2656 for i in a.shape:
2657 totalElements *= i
2658
2659 # If there are any -1 elements, figure out what that dimension must be
2660 totalOutputElements = 1
2661 for i in output_shape:
2662 if i != -1:
2663 totalOutputElements *= i
2664
2665 # And fill it in
2666 for i in range(len(output_shape)):
2667 if output_shape[i] == -1:
2668 output_shape[i] = totalElements // totalOutputElements
2669
Kevin Cheng550ccc52021-03-03 11:21:43 -08002670 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002671
2672 @staticmethod
2673 def sliceOp(ser, a, begin, size):
2674
2675 output_shape = size.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002676 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002677
2678 @staticmethod
2679 def tileOp(ser, a, multiples):
2680
2681 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002682 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07002683
2684 for i in range(len(output_shape)):
2685 output_shape[i] = a.shape[i] * multiples[i]
2686
Kevin Cheng550ccc52021-03-03 11:21:43 -08002687 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002688
2689 @staticmethod
2690 def transposeOp(ser, a, perms):
2691 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002692 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07002693
2694 for i in range(len(output_shape)):
2695 output_shape[i] = a.shape[perms[i]]
2696
Kevin Cheng550ccc52021-03-03 11:21:43 -08002697 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002698
2699 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -08002700 def gatherOp(ser, values, indices):
2701 assert len(values.shape) == 3
2702 assert len(indices.shape) == 2
2703 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07002704
Kevin Cheng77d0f762020-11-24 10:26:32 -08002705 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
2706
Kevin Cheng550ccc52021-03-03 11:21:43 -08002707 return ser.addOutput(output_shape, values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08002708
2709 @staticmethod
2710 def scatterOp(ser, values_in, indices, input):
2711 assert len(values_in.shape) == 3
2712 assert len(indices.shape) == 2
2713 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08002714 assert values_in.shape[0] == indices.shape[0] # N
2715 assert input.shape[1] == indices.shape[1] # W
2716 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08002717
2718 output_shape = values_in.shape
2719
Kevin Cheng550ccc52021-03-03 11:21:43 -08002720 return ser.addOutput(output_shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002721
2722 @staticmethod
2723 def tableOp(ser, input, table):
2724 # Same shape as the input, but with the type of the table.
Kevin Cheng550ccc52021-03-03 11:21:43 -08002725 return ser.addOutput(input.shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07002726
2727 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08002728 def resizeOp(
2729 ser,
2730 input,
2731 mode,
2732 stride,
2733 offset,
2734 shift,
2735 stride_fp,
2736 offset_fp,
2737 output_dims,
2738 input_dtype,
2739 output_dtype,
2740 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002741
2742 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
2743
Kevin Cheng77d0f762020-11-24 10:26:32 -08002744 if input_dtype == DType.FLOAT:
2745 if stride_fp[0] <= 0 or stride_fp[1] <= 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002746 ser.setExpectedFailure(True, "Negative or zero stride")
Kevin Cheng77d0f762020-11-24 10:26:32 -08002747 else:
2748 if stride[0] <= 0 or stride[1] <= 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002749 ser.setExpectedFailure(True, "Negative or zero stride")
Eric Kunzee5e26762020-10-13 16:11:07 -07002750
Kevin Chengaee1fac2020-11-11 13:54:06 -08002751 if mode == ResizeMode.BILINEAR:
2752 if input_dtype == DType.INT8:
2753 if output_dtype != DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002754 ser.setExpectedFailure(True, "Invalid output data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002755 elif input_dtype == DType.INT16:
2756 if output_dtype != DType.INT48:
Kevin Cheng989cb052021-04-28 16:29:44 -07002757 ser.setExpectedFailure(true, "Invalid output data type")
Kevin Cheng77d0f762020-11-24 10:26:32 -08002758 elif input_dtype == DType.FLOAT:
2759 if output_dtype != DType.FLOAT:
Kevin Cheng989cb052021-04-28 16:29:44 -07002760 ser.setExpectedFailure(true, "Invalid output data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002761 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002762 ser.setExpectedFailure(true, "Invalid input data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002763
2764 elif mode == ResizeMode.NEAREST:
2765 if input_dtype == DType.INT8:
2766 if output_dtype != DType.INT8:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002767 ser.setExpectedFailure(True, "Invalid output data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002768 elif input_dtype == DType.INT16:
2769 if output_dtype != DType.INT16:
Kevin Cheng989cb052021-04-28 16:29:44 -07002770 ser.setExpectedFailure(true, "Invalid output data type")
Kevin Cheng77d0f762020-11-24 10:26:32 -08002771 elif input_dtype == DType.FLOAT:
2772 if output_dtype != DType.FLOAT:
Kevin Cheng989cb052021-04-28 16:29:44 -07002773 ser.setExpectedFailure(true, "Invalid output data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002774 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002775 ser.setExpectedFailure(true, "Invalid input data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002776
2777 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002778 ser.setExpectedFailure(true, "Invalid resize mode")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002779
Kevin Cheng550ccc52021-03-03 11:21:43 -08002780 return ser.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002781
2782 @staticmethod
2783 def typeConversionOp(ser, val, out_dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002784 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002785
2786 @staticmethod
2787 def transposeConv2DOp(ser, ifm, output_shape):
Kevin Cheng3a478572021-01-22 17:21:02 -08002788 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002789 out_dtype = DType.INT32
2790 elif ifm.dtype == DType.INT16:
2791 out_dtype = DType.INT48
2792 elif ifm.dtype == DType.FLOAT:
2793 out_dtype = DType.FLOAT
2794 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002795 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002796
2797 if output_shape[1] <= 0 or output_shape[2] <= 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002798 ser.setExpectedFailure(True, "Negative output shape")
Eric Kunzee5e26762020-10-13 16:11:07 -07002799
Kevin Cheng550ccc52021-03-03 11:21:43 -08002800 return ser.addOutput(output_shape, out_dtype)